// Package commands provides CLI commands for the admin tool
package commands
import (
"context"
"database/sql"
"os"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"github.com/spf13/cobra"
)
// DatabaseCommands builds the "db" command group of the admin CLI and attaches
// its "stats" and "cleanup" subcommands. The user service, logger and database
// handle are forwarded unchanged to the subcommand constructors.
func DatabaseCommands(userService *services.UserService, logger *observability.Logger, db *sql.DB) *cobra.Command {
	group := &cobra.Command{
		Use:   "db",
		Short: "Database management commands",
		Long: `Database management commands for the quiz application.
Available commands:
stats - Show database statistics
cleanup - Run database cleanup operations`,
	}
	group.AddCommand(
		statsCmd(userService, logger, db),
		cleanupCmd(logger, db),
	)
	return group
}
// statsCmd builds the "db stats" subcommand; its RunE handler is produced by
// runStats with the same dependencies.
func statsCmd(userService *services.UserService, logger *observability.Logger, db *sql.DB) *cobra.Command {
	handler := runStats(userService, logger, db)
	cmd := &cobra.Command{
		Use:   "stats",
		Short: "Show database statistics",
		Long:  `Show database statistics including user counts and other metrics.`,
	}
	cmd.RunE = handler
	return cmd
}
// cleanupCmd returns the cleanup command
func cleanupCmd(logger *observability.Logger, db *sql.DB) *cobra.Command {
var statsOnly bool
cmd := &cobra.Command{
Use: "cleanup",
Short: "Run database cleanup operations",
Long: `Run database cleanup operations to remove old data.
This command will:
- Remove questions with legacy question types
- Remove orphaned user responses
Use --stats flag to see what would be cleaned up without actually performing the cleanup.`,
RunE: runCleanup(logger, &statsOnly, db),
}
// Add flags
cmd.Flags().BoolVar(&statsOnly, "stats", false, "Only show cleanup statistics, don't perform cleanup")
return cmd
}
// runStats builds the RunE handler for "db stats". The handler logs the config
// file path and a description of the connection, then logs how many users
// userService.GetAllUsers returned. A failure to load users is logged and
// returned wrapped as an internal error.
func runStats(userService *services.UserService, logger *observability.Logger, db *sql.DB) func(cmd *cobra.Command, args []string) error {
	return func(_ *cobra.Command, _ []string) error {
		ctx := context.Background()

		diagnostics := map[string]interface{}{
			"config_file": os.Getenv("QUIZ_CONFIG_FILE"),
			"database":    getDatabaseInfo(db),
		}
		logger.Info(ctx, "Diagnostic info", diagnostics)
		logger.Info(ctx, "Showing database statistics", map[string]interface{}{})

		users, err := userService.GetAllUsers(ctx)
		if err != nil {
			logger.Error(ctx, "Failed to get user statistics", err, map[string]interface{}{})
			return contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to get user statistics: %w", err)
		}

		// NOTE(review): "database" and "status" are fixed strings here, not
		// values read from db.
		summary := map[string]interface{}{
			"total_users": len(users),
			"database":    "PostgreSQL",
			"status":      "Connected",
		}
		logger.Info(ctx, "Database statistics", summary)
		return nil
	}
}
// runCleanup builds the RunE handler for "db cleanup".
//
// After logging diagnostics the handler requires a non-nil db and creates a
// cleanup service. When *statsOnly is false it runs RunFullCleanup; otherwise
// it only logs the counts returned by GetCleanupStats. statsOnly is
// dereferenced at run time so it reflects the parsed --stats flag.
func runCleanup(logger *observability.Logger, statsOnly *bool, db *sql.DB) func(cmd *cobra.Command, args []string) error {
	return func(_ *cobra.Command, _ []string) error {
		ctx := context.Background()

		logger.Info(ctx, "Diagnostic info", map[string]interface{}{
			"config_file": os.Getenv("QUIZ_CONFIG_FILE"),
			"database":    getDatabaseInfo(db),
		})
		logger.Info(ctx, "Running database cleanup", map[string]interface{}{"stats_only": *statsOnly})

		if db == nil {
			return contextutils.WrapErrorf(contextutils.ErrInternalError, "database connection not available")
		}
		cleanupService := services.NewCleanupServiceWithLogger(db, logger)

		if !*statsOnly {
			logger.Info(ctx, "Starting database cleanup", map[string]interface{}{"service": "cleanup"})
			if err := cleanupService.RunFullCleanup(ctx); err != nil {
				logger.Error(ctx, "Cleanup failed", err, map[string]interface{}{"service": "cleanup"})
				return contextutils.WrapErrorf(contextutils.ErrInternalError, "cleanup failed: %w", err)
			}
			logger.Info(ctx, "Database cleanup completed successfully", map[string]interface{}{"service": "cleanup"})
			return nil
		}

		// Stats-only mode: report what would be cleaned up without cleaning.
		stats, err := cleanupService.GetCleanupStats(ctx)
		if err != nil {
			logger.Error(ctx, "Failed to get cleanup stats", err, map[string]interface{}{"stats_only": true})
			return contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to get cleanup stats: %w", err)
		}
		legacy, orphaned := stats["legacy_questions"], stats["orphaned_responses"]
		logger.Info(ctx, "Database cleanup statistics", map[string]interface{}{"legacy_questions": legacy, "orphaned_responses": orphaned})

		total := legacy + orphaned
		message := "Cleanup would remove items"
		if total == 0 {
			message = "No cleanup needed - database is clean"
		}
		logger.Info(ctx, message, map[string]interface{}{"total": total})
		return nil
	}
}
package commands
import (
"context"
"fmt"
"os"
"syscall"
"golang.org/x/term"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"github.com/spf13/cobra"
)
// UserCommands builds the "user" command group of the admin CLI with its
// "list" and "reset-password" subcommands. databaseURL is forwarded only to
// "list", which logs a masked form of it as diagnostics.
func UserCommands(userService *services.UserService, logger *observability.Logger, databaseURL string) *cobra.Command {
	group := &cobra.Command{
		Use:   "user",
		Short: "User management commands",
		Long: `User management commands for the quiz application.
Available commands:
list - List all users
reset-password - Reset password for a specific user`,
	}
	group.AddCommand(
		listCmd(userService, logger, databaseURL),
		resetPasswordCmd(userService, logger),
	)
	return group
}
// listCmd builds the "user list" subcommand; its RunE handler is produced by
// runListUsers with the same dependencies.
func listCmd(userService *services.UserService, logger *observability.Logger, databaseURL string) *cobra.Command {
	handler := runListUsers(userService, logger, databaseURL)
	cmd := &cobra.Command{
		Use:   "list",
		Short: "List all users",
		Long:  `List all users in the database with their basic information.`,
	}
	cmd.RunE = handler
	return cmd
}
// resetPasswordCmd builds the "user reset-password [username]" subcommand.
//
// At most one positional argument (the username) is accepted; cobra rejects
// extra arguments instead of silently ignoring them. When the username is
// omitted, the handler built by runResetPassword prompts for it.
func resetPasswordCmd(userService *services.UserService, logger *observability.Logger) *cobra.Command {
	return &cobra.Command{
		Use:   "reset-password [username]",
		Short: "Reset password for a user",
		Long:  `Reset the password for a specific user. If username is not provided, you will be prompted for it.`,
		Args:  cobra.MaximumNArgs(1),
		RunE:  runResetPassword(userService, logger),
	}
}
// runListUsers builds the RunE handler for "user list".
//
// The handler logs diagnostics (config file path and the database URL passed
// through maskDatabaseURL), fetches all users and prints them to stdout as a
// fixed-width table. Nullable columns are shown as "N/A" when unset. With no
// users it only logs a message and prints nothing. A lookup failure is logged
// and returned wrapped.
func runListUsers(userService *services.UserService, logger *observability.Logger, databaseURL string) func(cmd *cobra.Command, args []string) error {
	return func(_ *cobra.Command, _ []string) error {
		ctx := context.Background()
		// Show diagnostic information
		logger.Info(ctx, "Admin command diagnostics", map[string]interface{}{"config_file": os.Getenv("QUIZ_CONFIG_FILE"), "database_url": maskDatabaseURL(databaseURL)})
		logger.Info(ctx, "Listing all users", map[string]interface{}{})
		users, err := userService.GetAllUsers(ctx)
		if err != nil {
			logger.Error(ctx, "Failed to get users", err, map[string]interface{}{})
			return contextutils.WrapError(err, "failed to get users")
		}
		if len(users) == 0 {
			logger.Info(ctx, "No users found in the database", nil)
			return nil
		}
		// Print header to stdout (user-facing table)
		fmt.Printf("%-5s %-20s %-30s %-15s %-10s %-10s %-10s\n", "ID", "Username", "Email", "Language", "Level", "AI Enabled", "Created")
		// Underline the header with 120 dashes. A freshly made []byte is all
		// zero bytes, so it must be filled explicitly or NULs are printed.
		separator := make([]byte, 120)
		for i := range separator {
			separator[i] = '-'
		}
		fmt.Println(string(separator))
		// Print each user
		for _, user := range users {
			aiEnabled := "No"
			if user.AIEnabled.Valid && user.AIEnabled.Bool {
				aiEnabled = "Yes"
			}
			email := "N/A"
			if user.Email.Valid {
				email = user.Email.String
			}
			language := "N/A"
			if user.PreferredLanguage.Valid {
				language = user.PreferredLanguage.String
			}
			level := "N/A"
			if user.CurrentLevel.Valid {
				level = user.CurrentLevel.String
			}
			fmt.Printf("%-5d %-20s %-30s %-15s %-10s %-10s %-10s\n",
				user.ID,
				user.Username,
				email,
				language,
				level,
				aiEnabled,
				user.CreatedAt.Format("2006-01-02"),
			)
		}
		logger.Info(ctx, "Listed users", map[string]interface{}{"total": len(users)})
		return nil
	}
}
// runResetPassword builds the RunE handler for "user reset-password".
//
// The username comes from the first positional argument or, if absent, from
// a stdin prompt. The new password is read twice without echo via
// term.ReadPassword and must be non-empty and match. The user is then looked
// up by username and its password updated through userService. Errors are
// returned wrapped; lookup/update failures are also logged.
func runResetPassword(userService *services.UserService, logger *observability.Logger) func(cmd *cobra.Command, args []string) error {
	return func(_ *cobra.Command, args []string) error {
		ctx := context.Background()
		var username string
		var newPassword string
		// Get username from args or prompt
		if len(args) > 0 {
			username = args[0]
		} else {
			fmt.Print("Enter username: ")
			// On an empty line (or EOF) Scanln fails with nothing stored in
			// username; let the "username is required" check below report that
			// instead of a generic read error. An error with a non-empty
			// username means the line held extra tokens.
			if _, err := fmt.Scanln(&username); err != nil && username != "" {
				return contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to read username: %w", err)
			}
		}
		if username == "" {
			return contextutils.ErrorWithContextf("username is required")
		}
		// Prompt for password securely (input is not echoed)
		fmt.Print("Enter new password: ")
		passwordBytes, err := term.ReadPassword(int(syscall.Stdin))
		if err != nil {
			return contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to read password: %w", err)
		}
		newPassword = string(passwordBytes)
		fmt.Println() // New line after password input
		if newPassword == "" {
			return contextutils.ErrorWithContextf("password cannot be empty")
		}
		// Confirm password
		fmt.Print("Confirm new password: ")
		confirmBytes, err := term.ReadPassword(int(syscall.Stdin))
		if err != nil {
			return contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to read password confirmation: %w", err)
		}
		confirmPassword := string(confirmBytes)
		fmt.Println() // New line after password input
		if newPassword != confirmPassword {
			return contextutils.ErrorWithContextf("passwords do not match")
		}
		logger.Info(ctx, "Resetting password for user", map[string]interface{}{
			"username": username,
		})
		// Get user by username
		user, err := userService.GetUserByUsername(ctx, username)
		if err != nil {
			logger.Error(ctx, "Failed to get user", err, map[string]interface{}{"username": username})
			return contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to get user '%s': %w", username, err)
		}
		if user == nil {
			logger.Error(ctx, "User not found", nil, map[string]interface{}{"username": username})
			return contextutils.ErrorWithContextf("user '%s' not found", username)
		}
		// Update the password
		err = userService.UpdateUserPassword(ctx, user.ID, newPassword)
		if err != nil {
			logger.Error(ctx, "Failed to update password", err, map[string]interface{}{
				"username": username,
				"user_id":  user.ID,
			})
			return contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to update password for user '%s': %w", username, err)
		}
		fmt.Printf("✅ Password successfully reset for user '%s' (ID: %d)\n", username, user.ID)
		logger.Info(ctx, "Password reset successful", map[string]interface{}{
			"username": username,
			"user_id":  user.ID,
		})
		return nil
	}
}
package commands
import (
"database/sql"
"fmt"
"strings"
)
// maskDatabaseURL hides the credentials of a database URL so it can be shown
// or logged.
//
// Everything before the LAST "@" (the userinfo) is replaced by "***:***".
// Using the last "@" means a password containing an unescaped "@" is still
// fully masked instead of leaking the raw URL. The original scheme (e.g.
// "postgresql://") is preserved when present; otherwise "postgres://" is
// used. Strings without "@" carry no userinfo and are returned unchanged.
func maskDatabaseURL(url string) string {
	at := strings.LastIndex(url, "@")
	if at < 0 {
		return url
	}
	scheme := "postgres://"
	if i := strings.Index(url, "://"); i >= 0 && i < at {
		scheme = url[:i+len("://")]
	}
	return scheme + "***:***@" + url[at+1:]
}
// getDatabaseInfo returns a short, human-readable description of db for
// diagnostic logging:
//   - "Not connected" for a nil handle,
//   - "Connected (unknown database)" if current_database() cannot be read,
//   - "Connected to <db>" if inet_server_addr() cannot be read,
//   - "Connected to <db> on <addr>" otherwise.
//
// NOTE(review): inet_server_addr() is NULL on Unix-socket connections, which
// makes the Scan into a string fail and yields the "Connected to <db>" form.
func getDatabaseInfo(db *sql.DB) string {
	if db == nil {
		return "Not connected"
	}
	var name string
	if err := db.QueryRow("SELECT current_database()").Scan(&name); err != nil {
		return "Connected (unknown database)"
	}
	var addr string
	if err := db.QueryRow("SELECT inet_server_addr()::text").Scan(&addr); err != nil {
		return fmt.Sprintf("Connected to %s", name)
	}
	return fmt.Sprintf("Connected to %s on %s", name, addr)
}
// Package main provides the main entry point for the quiz application admin CLI tool.
package main
import (
"context"
"fmt"
"os"
"quizapp/cmd/adm/commands"
"quizapp/internal/config"
"quizapp/internal/database"
"quizapp/internal/observability"
"quizapp/internal/services"
"github.com/spf13/cobra"
)
// Global variables for shared resources.
//
// cfg, logger and userService are assigned during start-up (configuration
// load, observability setup, service construction) before the cobra command
// tree is built; the logger, userService and values taken from cfg are then
// passed to the subcommand constructors.
var (
cfg *config.Config
logger *observability.Logger
userService *services.UserService
)
// main is the entry point of the admin CLI. It delegates to runAdmin and exits
// with the status code it returns. Keeping os.Exit out of the function that
// registers defers guarantees the database is closed and the telemetry
// providers are shut down on every path, including a failed command.
func main() {
	os.Exit(runAdmin())
}

// runAdmin locates and loads the configuration, sets up observability,
// connects to the database (without migrations), builds the cobra command
// tree and executes it. It returns 0 on success and 1 on any failure.
func runAdmin() int {
	ctx := context.Background()

	// Set default config file if not already set.
	if os.Getenv("QUIZ_CONFIG_FILE") == "" {
		// Try to find the config file in common locations; the paths are
		// relative to the current working directory.
		defaultPaths := []string{
			"../merged.config.yaml",    // From backend/cmd/adm/
			"../../merged.config.yaml", // From backend/cmd/adm/ (alternative)
			"merged.config.yaml",       // Current directory
		}
		for _, path := range defaultPaths {
			if _, err := os.Stat(path); err == nil {
				if err := os.Setenv("QUIZ_CONFIG_FILE", path); err != nil {
					fmt.Fprintf(os.Stderr, "Failed to set QUIZ_CONFIG_FILE environment variable: %v\n", err)
					return 1
				}
				break
			}
		}
	}

	// Load configuration (stored in the package-level cfg).
	var err error
	cfg, err = config.NewConfig()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to load configuration: %v\n", err)
		return 1
	}

	// Override log level for admin tool
	cfg.Server.LogLevel = "error"
	// Disable all OpenTelemetry features for admin CLI to avoid connection errors
	cfg.OpenTelemetry.EnableTracing = false
	cfg.OpenTelemetry.EnableMetrics = false
	cfg.OpenTelemetry.EnableLogging = false

	// Setup observability (tracing/metrics/logging)
	tp, mp, loggerInstance, err := observability.SetupObservability(&cfg.OpenTelemetry, "quiz-admin")
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to initialize observability: %v\n", err)
		return 1
	}
	// Store logger globally
	logger = loggerInstance

	// Runs on every return path below, because runAdmin returns instead of
	// calling os.Exit.
	defer func() {
		if tp != nil {
			if err := tp.Shutdown(context.TODO()); err != nil {
				logger.Warn(ctx, "Error shutting down tracer provider", map[string]interface{}{"error": err.Error(), "provider": "tracer"})
			}
		}
		if mp != nil {
			if err := mp.Shutdown(context.TODO()); err != nil {
				logger.Warn(ctx, "Error shutting down meter provider", map[string]interface{}{"error": err.Error(), "provider": "meter"})
			}
		}
	}()

	// Initialize database manager
	dbManager := database.NewManager(logger)
	// Initialize database connection with configuration (no migrations for admin tool)
	// NOTE(review): db_url is logged verbatim and may contain credentials.
	db, err := dbManager.InitDBWithoutMigrations(cfg.Database)
	if err != nil {
		logger.Error(ctx, "Failed to connect to database", err, map[string]interface{}{"db_url": cfg.Database.URL})
		return 1
	}
	defer func() {
		if err := db.Close(); err != nil {
			logger.Warn(ctx, "Warning: failed to close database connection", map[string]interface{}{"error": err.Error(), "db_url": cfg.Database.URL})
		}
	}()

	// Initialize services
	userService = services.NewUserServiceWithLogger(db, cfg, logger)

	// Create the root command
	rootCmd := &cobra.Command{
		Use:   "adm",
		Short: "Quiz Application Administration Tool",
		Long: `Quiz Application Administration Tool
A comprehensive CLI tool for administering the quiz application.
Provides commands for user management, database operations, and system administration.`,
		Run: func(cmd *cobra.Command, _ []string) {
			// Show help if no subcommand provided
			if err := cmd.Help(); err != nil {
				fmt.Printf("Error showing help: %v\n", err)
			}
		},
	}

	// Add subcommands with initialized services
	rootCmd.AddCommand(commands.UserCommands(userService, logger, cfg.Database.URL))
	rootCmd.AddCommand(commands.DatabaseCommands(userService, logger, db))

	// Execute the command; cobra itself reports the error to the user.
	if err := rootCmd.Execute(); err != nil {
		return 1
	}
	return 0
}
// Package main provides a CLI tool for running the worker to generate questions for a specific user.
package main
import (
"context"
"flag"
"fmt"
"os"
"strings"
"time"
"quizapp/internal/config"
"quizapp/internal/database"
"quizapp/internal/models"
"quizapp/internal/observability"
"quizapp/internal/services"
"quizapp/internal/worker"
)
// main is the entry point of the CLI worker. It delegates to runWorker and
// exits with the status code it returns, so deferred cleanup (database close,
// telemetry shutdown) runs before the process exits.
func main() {
	os.Exit(runWorker())
}

// runWorker parses and validates the command line, loads configuration, looks
// up the target user, applies any AI overrides and generates questions through
// a CLI-scoped worker instance. It returns 0 on success and 1 on any error;
// errors are reported on stderr (and through the structured logger once it is
// available).
func runWorker() int {
	ctx := context.Background()
	// Define command line flags
	var (
		username     = flag.String("username", "", "Username to generate questions for (required)")
		level        = flag.String("level", "", "Override user's current level (optional)")
		language     = flag.String("language", "", "Override user's preferred language (optional)")
		questionType = flag.String("type", "vocabulary", "Question type: vocabulary, fill_blank, qa, reading_comprehension")
		topic        = flag.String("topic", "", "Specific topic for questions (optional)")
		count        = flag.Int("count", 5, "Number of questions to generate")
		aiProvider   = flag.String("ai-provider", "", "Override AI provider (optional)")
		aiModel      = flag.String("ai-model", "", "Override AI model (optional)")
		aiAPIKey     = flag.String("ai-api-key", "", "Override AI API key (optional)")
		help         = flag.Bool("help", false, "Show help message")
	)
	flag.Parse()

	if *help {
		// Load the configuration so the usage text can list valid levels,
		// languages and providers; fall back to nil if it cannot be loaded.
		helpCfg, cfgErr := config.NewConfig()
		if cfgErr != nil {
			helpCfg = nil
		}
		printUsage(helpCfg)
		return 0
	}

	if *username == "" {
		fmt.Fprintln(os.Stderr, "Error: --username flag is required")
		return 1
	}
	if *count <= 0 {
		fmt.Fprintf(os.Stderr, "Error: --count must be a positive number, got %d\n", *count)
		return 1
	}

	// Load configuration
	cfg, err := config.NewConfig()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to load configuration: %v\n", err)
		return 1
	}

	// Setup observability (tracing/metrics/logging)
	tp, mp, logger, err := observability.SetupObservability(&cfg.OpenTelemetry, "quiz-cli-worker")
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to initialize observability: %v\n", err)
		return 1
	}
	defer func() {
		if tp != nil {
			if err := tp.Shutdown(context.TODO()); err != nil {
				logger.Warn(ctx, "Error shutting down tracer provider", map[string]interface{}{"error": err.Error()})
			}
		}
		if mp != nil {
			if err := mp.Shutdown(context.TODO()); err != nil {
				logger.Warn(ctx, "Error shutting down meter provider", map[string]interface{}{"error": err.Error()})
			}
		}
	}()

	logger.Info(ctx, "Starting quiz CLI worker", map[string]interface{}{
		"username":      *username,
		"question_type": *questionType,
		"count":         *count,
	})

	// Validate question type
	validTypes := map[string]models.QuestionType{
		"vocabulary":            models.Vocabulary,
		"fill_blank":            models.FillInBlank,
		"qa":                    models.QuestionAnswer,
		"reading_comprehension": models.ReadingComprehension,
	}
	qType, valid := validTypes[strings.ToLower(*questionType)]
	if !valid {
		logger.Error(ctx, "Invalid question type", nil, map[string]interface{}{"question_type": *questionType})
		fmt.Fprintf(os.Stderr, "Error: Invalid question type '%s'\n", *questionType)
		return 1
	}

	// Validate the level/language overrides. The checks are case-insensitive,
	// so normalize accepted values to the spelling used in the configuration
	// before they are passed on.
	levelOverride := *level
	if levelOverride != "" {
		validLevels := cfg.GetAllLevels()
		if !isValidLevel(levelOverride, validLevels) {
			logger.Error(ctx, "Invalid level", nil, map[string]interface{}{"level": levelOverride})
			fmt.Fprintf(os.Stderr, "Error: Invalid level '%s'\n", levelOverride)
			return 1
		}
		levelOverride = canonicalOption(levelOverride, validLevels)
	}
	languageOverride := *language
	if languageOverride != "" {
		validLanguages := cfg.GetLanguages()
		if !isValidLanguage(languageOverride, validLanguages) {
			logger.Error(ctx, "Invalid language", nil, map[string]interface{}{"language": languageOverride})
			fmt.Fprintf(os.Stderr, "Error: Invalid language '%s'\n", languageOverride)
			return 1
		}
		languageOverride = canonicalOption(languageOverride, validLanguages)
	}

	// Initialize database manager with logger
	dbManager := database.NewManager(logger)
	// Initialize database connection with configuration
	db, err := dbManager.InitDBWithoutMigrations(cfg.Database)
	if err != nil {
		logger.Error(ctx, "Failed to connect to database", err, map[string]interface{}{"db_url": cfg.Database.URL})
		fmt.Fprintf(os.Stderr, "Failed to connect to database: %v\n", err)
		return 1
	}
	defer func() {
		if err := db.Close(); err != nil {
			logger.Warn(ctx, "Warning: failed to close database connection", map[string]interface{}{"error": err.Error(), "db_url": cfg.Database.URL})
		}
	}()

	// Initialize services
	userService := services.NewUserServiceWithLogger(db, cfg, logger)
	learningService := services.NewLearningServiceWithLogger(db, cfg, logger)
	questionService := services.NewQuestionServiceWithLogger(db, learningService, cfg, logger)
	aiService := services.NewAIService(cfg, logger)
	workerService := services.NewWorkerServiceWithLogger(db, logger)

	// Get user by username
	user, err := userService.GetUserByUsername(ctx, *username)
	if err != nil {
		logger.Error(ctx, "Failed to get user", err, map[string]interface{}{"username": *username})
		fmt.Fprintf(os.Stderr, "Failed to get user: %v\n", err)
		return 1
	}
	if user == nil {
		logger.Error(ctx, "User not found", nil, map[string]interface{}{"username": *username})
		fmt.Fprintf(os.Stderr, "User not found: %s\n", *username)
		return 1
	}
	logger.Info(ctx, "Found user", map[string]interface{}{"username": user.Username, "user_id": user.ID})

	// Apply AI overrides if provided (in memory only, for this run).
	if *aiProvider != "" {
		user.AIProvider.String = *aiProvider
		user.AIProvider.Valid = true
		user.AIEnabled.Bool = true
		user.AIEnabled.Valid = true
	}
	if *aiModel != "" {
		user.AIModel.String = *aiModel
		user.AIModel.Valid = true
	}
	if *aiAPIKey != "" {
		// The key is stored for --ai-provider when given (already applied to
		// user above), otherwise for the user's current provider.
		provider := user.AIProvider.String
		if provider == "" {
			logger.Error(ctx, "No AI provider for API key", nil, map[string]interface{}{"username": user.Username, "user_id": user.ID})
			fmt.Fprintln(os.Stderr, "Error: --ai-api-key requires --ai-provider when the user has no AI provider configured")
			return 1
		}
		if err := userService.SetUserAPIKey(ctx, user.ID, provider, *aiAPIKey); err != nil {
			logger.Error(ctx, "Failed to set API key", err, map[string]interface{}{"user_id": user.ID, "provider": provider})
			fmt.Fprintf(os.Stderr, "Failed to set API key: %v\n", err)
			return 1
		}
	}

	// Check if user has AI enabled (after potential overrides)
	if !user.AIEnabled.Valid || !user.AIEnabled.Bool {
		logger.Warn(ctx, "User does not have AI enabled", map[string]interface{}{"username": user.Username, "user_id": user.ID})
		logger.Info(ctx, "You may want to enable AI for this user first or use --ai-provider flag")
	}

	// Determine language and level to use
	languageToUse := user.PreferredLanguage.String
	if languageOverride != "" {
		languageToUse = languageOverride
	}
	levelToUse := user.CurrentLevel.String
	if levelOverride != "" {
		levelToUse = levelOverride
	}

	// Validate that we have required settings
	if languageToUse == "" {
		logger.Error(ctx, "No language specified", nil, map[string]interface{}{"username": user.Username, "user_id": user.ID})
		fmt.Fprintln(os.Stderr, "Error: No language specified. User has no preferred language and --language flag not provided")
		return 1
	}
	if levelToUse == "" {
		logger.Error(ctx, "No level specified", nil, map[string]interface{}{"username": user.Username, "user_id": user.ID})
		fmt.Fprintln(os.Stderr, "Error: No level specified. User has no current level and --level flag not provided")
		return 1
	}

	// Print configuration
	fmt.Printf("=== CLI Worker Configuration ===\n")
	fmt.Printf("User: %s (ID: %d)\n", user.Username, user.ID)
	fmt.Printf("Language: %s\n", languageToUse)
	fmt.Printf("Level: %s\n", levelToUse)
	fmt.Printf("Question Type: %s\n", qType)
	fmt.Printf("Count: %d\n", *count)
	if *topic != "" {
		fmt.Printf("Topic: %s\n", *topic)
	}
	if user.AIProvider.Valid && user.AIProvider.String != "" {
		fmt.Printf("AI Provider: %s\n", user.AIProvider.String)
	}
	if user.AIModel.Valid && user.AIModel.String != "" {
		fmt.Printf("AI Model: %s\n", user.AIModel.String)
	}
	fmt.Printf("===============================\n\n")

	// Create the remaining services and a minimal worker for generation.
	emailService := services.CreateEmailService(cfg, logger)
	dailyQuestionService := services.NewDailyQuestionService(db, logger, questionService, learningService)
	workerInstance := worker.NewWorker(userService, questionService, aiService, learningService, workerService, dailyQuestionService, emailService, nil, "cli", cfg, logger)

	// Bound generation with a timeout. A separate variable keeps the deferred
	// cleanup closures above on the original, never-cancelled ctx.
	genCtx, cancel := context.WithTimeout(ctx, config.CLIWorkerTimeout)
	defer cancel()

	logger.Info(genCtx, "CLI worker starting question generation", map[string]interface{}{
		"user_id":       user.ID,
		"username":      user.Username,
		"question_type": qType,
		"count":         *count,
		"language":      languageToUse,
		"level":         levelToUse,
	})

	// Generate questions
	fmt.Printf("Starting question generation...\n")
	startTime := time.Now()
	result, err := workerInstance.GenerateQuestionsForUser(genCtx, user, languageToUse, levelToUse, qType, *count, *topic)
	duration := time.Since(startTime)
	if err != nil {
		logger.Error(ctx, "Question generation failed", err, map[string]interface{}{"user_id": user.ID, "duration": duration.String()})
		fmt.Printf("\n❌ Question generation failed after %v\n", duration)
		fmt.Printf("Error: %v\n", err)
		return 1
	}
	fmt.Printf("\n✅ Question generation completed successfully in %v\n", duration)
	fmt.Printf("Result: %s\n", result)
	return 0
}

// canonicalOption returns the entry of options that equals value
// case-insensitively, or value itself when no entry matches.
func canonicalOption(value string, options []string) string {
	for _, option := range options {
		if strings.EqualFold(value, option) {
			return option
		}
	}
	return value
}
// isValidLevel reports whether level matches any entry of validLevels,
// ignoring case (Unicode simple case folding via strings.EqualFold).
func isValidLevel(level string, validLevels []string) bool {
	for i := range validLevels {
		if strings.EqualFold(validLevels[i], level) {
			return true
		}
	}
	return false
}
// isValidLanguage reports whether language matches any entry of
// validLanguages, ignoring case.
func isValidLanguage(language string, validLanguages []string) bool {
	found := false
	for _, candidate := range validLanguages {
		if found = strings.EqualFold(language, candidate); found {
			break
		}
	}
	return found
}
// printUsage writes usage help for cli-worker to stdout.
//
// The flag list is generated from the flags registered on flag.CommandLine,
// so names, descriptions and defaults always match what is actually parsed.
// When cfg is non-nil, the valid levels, languages and AI provider codes from
// the configuration are listed too; when cfg is nil those lists cannot be
// produced and a notice is written to stderr instead.
func printUsage(cfg *config.Config) {
	fmt.Printf("Usage: cli-worker -username <name> [flags]\n")
	fmt.Printf("Flags:\n")
	flag.VisitAll(func(f *flag.Flag) {
		if f.DefValue == "" || f.DefValue == "false" {
			fmt.Printf("  -%s\t%s\n", f.Name, f.Usage)
			return
		}
		fmt.Printf("  -%s\t%s (default %s)\n", f.Name, f.Usage, f.DefValue)
	})
	fmt.Println()

	if cfg == nil {
		fmt.Fprintf(os.Stderr, "Error: Configuration is missing or invalid; valid levels, languages and providers cannot be listed.\n")
		return
	}
	fmt.Printf("Valid levels: %s\n", strings.Join(cfg.GetAllLevels(), ", "))
	fmt.Printf("Valid languages: %s\n", strings.Join(cfg.GetLanguages(), ", "))
	// A nil/empty provider list yields an empty name list, printed as
	// "Valid providers: ".
	providerNames := make([]string, 0, len(cfg.Providers))
	for _, p := range cfg.Providers {
		providerNames = append(providerNames, p.Code)
	}
	fmt.Printf("Valid providers: %s\n", strings.Join(providerNames, ", "))
}
// Package main provides a small CLI utility to reset the application's
// database to a clean state. It is intended for local development and
// testing only and will permanently delete all data when run.
package main
import (
"bufio"
"context"
"fmt"
"os"
"strings"
"quizapp/internal/config"
"quizapp/internal/database"
"quizapp/internal/observability"
"quizapp/internal/services"
)
// fatalIfErr logs the error with context and exits
func fatalIfErr(ctx context.Context, logger *observability.Logger, msg string, err error, fields map[string]interface{}) {
logger.Error(ctx, msg, err, fields)
os.Exit(1)
}
// main resets the application database after interactive confirmation: it
// runs the migrations, recreates the configured admin user and prints the
// admin credentials. It is meant for local development/testing only.
//
// The database URL is masked with maskDatabaseURL for both console output
// and structured logs, because the raw URL may embed the database password.
func main() {
	ctx := context.Background()
	// Load configuration first
	cfg, err := config.NewConfig()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to load configuration: %v\n", err)
		os.Exit(1)
	}
	// Setup observability (tracing/metrics/logging)
	tp, mp, logger, err := observability.SetupObservability(&cfg.OpenTelemetry, "reset-db")
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to initialize observability: %v\n", err)
		os.Exit(1)
	}
	// NOTE(review): fatalIfErr calls os.Exit, which skips this and the
	// db.Close defer below on fatal paths.
	defer func() {
		if tp != nil {
			if err := tp.Shutdown(context.TODO()); err != nil {
				logger.Warn(ctx, "Error shutting down tracer provider", map[string]interface{}{"error": err.Error(), "provider": "tracer"})
			}
		}
		if mp != nil {
			if err := mp.Shutdown(context.TODO()); err != nil {
				logger.Warn(ctx, "Error shutting down meter provider", map[string]interface{}{"error": err.Error(), "provider": "meter"})
			}
		}
	}()
	fmt.Println("⚠️ DATABASE RESET UTILITY ⚠️")
	fmt.Println("=============================")
	fmt.Println("This will PERMANENTLY DELETE ALL DATA in the database!")
	fmt.Println("This includes:")
	fmt.Println("- All users (including admin)")
	fmt.Println("- All questions")
	fmt.Println("- All user responses")
	fmt.Println("- All performance metrics")
	fmt.Println("")
	logger.Info(ctx, "Attempting to reset the database", map[string]interface{}{"service": "reset-db"})
	if cfg.Database.URL == "" {
		fatalIfErr(ctx, logger, "Database URL is empty", nil, map[string]interface{}{"error": "Database URL is empty. Cannot proceed with reset."})
	}
	// Mask credentials once; use the masked form everywhere it is shown or logged.
	maskedURL := maskDatabaseURL(cfg.Database.URL)
	// Print database info
	fmt.Println("📊 Database Information:")
	fmt.Printf("URL: %s\n", maskedURL)
	fmt.Println("")
	// Confirm with user
	if !confirmReset() {
		fmt.Println("Reset cancelled.")
		return
	}
	// Initialize database manager with logger
	dbManager := database.NewManager(logger)
	// Initialize database connection with configuration
	db, err := dbManager.InitDBWithConfig(cfg.Database)
	if err != nil {
		fatalIfErr(ctx, logger, "Failed to connect to database", err, map[string]interface{}{"db_url": maskedURL})
	}
	defer func() {
		if err := db.Close(); err != nil {
			logger.Warn(ctx, "Warning: failed to close database connection", map[string]interface{}{"error": err.Error(), "db_url": maskedURL})
		}
	}()
	// Initialize services
	userService := services.NewUserServiceWithLogger(db, cfg, logger)
	// Drop all tables
	// NOTE(review): no drop is actually performed here, despite the warning
	// printed above; only the migrations below run. A real reset needs a
	// DropAllTables-style operation on the database manager.
	fmt.Println("🗑️ Dropping all tables...")
	logger.Info(ctx, "Dropping all tables", map[string]interface{}{"db_url": maskedURL, "service": "reset-db"})
	// Run migrations
	fmt.Println("🔄 Running database migrations...")
	logger.Info(ctx, "Running database migrations", map[string]interface{}{"db_url": maskedURL, "service": "reset-db"})
	if err := dbManager.RunMigrations(db); err != nil {
		fatalIfErr(ctx, logger, "Failed to run migrations", err, map[string]interface{}{"db_url": maskedURL})
	}
	fmt.Println("✅ Database migrations completed successfully!")
	logger.Info(ctx, "Database migrations completed successfully", map[string]interface{}{"db_url": maskedURL, "service": "reset-db"})
	// Recreate admin user immediately
	fmt.Printf("Recreating admin user '%s'...\n", cfg.Server.AdminUsername)
	logger.Info(ctx, "Recreating admin user", map[string]interface{}{"username": cfg.Server.AdminUsername, "service": "reset-db"})
	// Ensure admin user exists
	if err := userService.EnsureAdminUserExists(ctx, cfg.Server.AdminUsername, cfg.Server.AdminPassword); err != nil {
		fatalIfErr(ctx, logger, "Failed to ensure admin user exists", err, map[string]interface{}{"admin_username": cfg.Server.AdminUsername})
	}
	fmt.Println("✅ Admin user recreated successfully!")
	logger.Info(ctx, "Admin user recreated successfully", map[string]interface{}{"username": cfg.Server.AdminUsername, "service": "reset-db"})
	fmt.Println("")
	// Print admin credentials (development tool: the password is shown in clear text).
	fmt.Printf("\nAdmin user credentials:\n")
	fmt.Printf(" Username: %s\n", cfg.Server.AdminUsername)
	fmt.Printf(" Password: %s\n", cfg.Server.AdminPassword)
	fmt.Println("")
	fmt.Println("✅ Database is now ready to use!")
	fmt.Println("- You can now start the server or use the existing running instance")
	fmt.Println("- Use the credentials above to log into the application")
}
// confirmReset asks on stdin whether the database should really be reset.
//
// It returns true only for "yes" (case-insensitive, surrounding whitespace
// ignored), false for "no" or an empty answer, and re-prompts for anything
// else. An answer on the final line without a trailing newline is still
// honoured. If reading fails, or stdin is exhausted without an answer, it
// returns false rather than re-prompting forever.
func confirmReset() bool {
	reader := bufio.NewReader(os.Stdin)
	for {
		fmt.Print("Are you sure you want to reset the database? (type 'yes' to confirm): ")
		response, err := reader.ReadString('\n')
		answer := strings.TrimSpace(strings.ToLower(response))
		if err != nil && answer == "" {
			// Once input is exhausted ReadString keeps returning the same error
			// (typically io.EOF), so retrying would loop forever.
			fmt.Println("Error reading input:", err)
			return false
		}
		switch answer {
		case "yes":
			return true
		case "no", "":
			return false
		default:
			fmt.Println("Please type 'yes' to confirm or 'no' to cancel.")
		}
	}
}
// maskDatabaseURL hides the credentials of a database URL so it can be shown
// or logged.
//
// Everything before the LAST "@" (the userinfo) is replaced by "***:***".
// Using the last "@" means a password containing an unescaped "@" is still
// fully masked instead of leaking the raw URL. The original scheme (e.g.
// "postgresql://") is preserved when present; otherwise "postgres://" is
// used. Strings without "@" carry no userinfo and are returned unchanged.
func maskDatabaseURL(url string) string {
	at := strings.LastIndex(url, "@")
	if at < 0 {
		return url
	}
	scheme := "postgres://"
	if i := strings.Index(url, "://"); i >= 0 && i < at {
		scheme = url[:i+len("://")]
	}
	return scheme + "***:***@" + url[at+1:]
}
// Package main provides the main entry point for the quiz application backend server.
// It sets up the HTTP server, database connections, middleware, and API routes.
package main
import (
	"context"
	"fmt"
	"net/http"
	"os"
	"os/signal"
	"syscall"
	"time"

	"quizapp/internal/config"
	"quizapp/internal/di"
	"quizapp/internal/handlers"
	"quizapp/internal/observability"
	contextutils "quizapp/internal/utils"

	"github.com/gin-gonic/gin"
)
// Application encapsulates the main application logic and can be tested.
// NewApplication fills it with the dependency-injection container the
// services were resolved from and the gin router built from those services.
type Application struct {
// container is the DI container passed to NewApplication.
container di.ServiceContainerInterface
// router is the engine returned by handlers.NewRouter; Run serves it.
router *gin.Engine
}
// NewApplication resolves every service the HTTP layer needs from container
// and wires them, together with the container's config and logger, into a
// router via handlers.NewRouter.
//
// Services are fetched in a fixed order; the first lookup that fails aborts
// construction and its error is returned wrapped with the service's name.
func NewApplication(container di.ServiceContainerInterface) (*Application, error) {
	users, err := container.GetUserService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get user service")
	}
	questions, err := container.GetQuestionService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get question service")
	}
	learning, err := container.GetLearningService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get learning service")
	}
	ai, err := container.GetAIService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get AI service")
	}
	workers, err := container.GetWorkerService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get worker service")
	}
	daily, err := container.GetDailyQuestionService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get daily question service")
	}
	oauth, err := container.GetOAuthService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get OAuth service")
	}
	hints, err := container.GetGenerationHintService()
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get generation hint service")
	}

	engine := handlers.NewRouter(
		container.GetConfig(),
		users, questions, learning, ai, workers, daily, oauth, hints,
		container.GetLogger(),
	)
	app := &Application{router: engine, container: container}
	return app, nil
}
// Run serves the application's router on ":"+port until ctx is cancelled or
// the server fails.
//
// The router is mounted on an explicit *http.Server rather than started with
// gin's Engine.Run, because Engine.Run offers no way to stop the listener. A
// cancelled ctx used to make Run return while the server kept serving and
// holding the port. Now cancellation triggers http.Server.Shutdown. Shutdown
// stops accepting connections and waits, up to serverShutdownTimeout, for
// in-flight requests to finish.
//
// It returns nil after a clean ctx-triggered shutdown. It returns a wrapped
// error if the server could not start or stopped unexpectedly, or if the
// graceful shutdown did not complete in time.
func (a *Application) Run(ctx context.Context, port string) error {
	// Upper bound on how long in-flight requests may drain after ctx is cancelled.
	const serverShutdownTimeout = 30 * time.Second

	srv := &http.Server{
		Addr:    ":" + port,
		Handler: a.router,
	}

	// Buffered so the goroutine can always deliver its result and exit, even
	// after Run has already returned via the ctx branch.
	serverErr := make(chan error, 1)
	go func() {
		// ListenAndServe always returns a non-nil error; ErrServerClosed is the
		// expected result of Shutdown and is not a failure.
		if err := srv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
			serverErr <- err
		}
	}()

	select {
	case <-ctx.Done():
		// ctx is already cancelled, so the drain needs its own deadline context.
		shutdownCtx, cancel := context.WithTimeout(context.Background(), serverShutdownTimeout)
		defer cancel()
		if err := srv.Shutdown(shutdownCtx); err != nil {
			return contextutils.WrapError(err, "failed to shut down HTTP server")
		}
		return nil
	case err := <-serverErr:
		return contextutils.WrapError(err, "server failed")
	}
}
// Shutdown releases the application's resources by delegating to the
// dependency-injection container. ctx bounds how long the container may take.
// The container's error, if any, is returned unchanged.
func (a *Application) Shutdown(ctx context.Context) error {
	services := a.container
	err := services.Shutdown(ctx)
	return err
}
// main is the entry point of the quiz backend server.
//
// It loads configuration, sets up tracing/metrics/logging and initializes
// the dependency-injection container (ensuring the admin user exists). It
// then serves the API until SIGINT/SIGTERM arrives or the server fails. On a
// signal it first stops the HTTP server and waits for Run to return, then
// shuts the container down within 30 seconds.
//
// Failures record a non-zero exit code and return instead of calling
// os.Exit on the spot. os.Exit does not run deferred functions, so the
// tracer and meter providers would never be flushed on exactly the runs
// whose telemetry matters most. The exit code is applied by the
// first-registered defer, which runs after every other deferred cleanup.
func main() {
	exitCode := 0
	defer func() {
		if exitCode != 0 {
			os.Exit(exitCode)
		}
	}()

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// Setup graceful shutdown
	shutdownCh := make(chan os.Signal, 1)
	signal.Notify(shutdownCh, syscall.SIGINT, syscall.SIGTERM)

	// Load configuration
	cfg, err := config.NewConfig()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to load configuration: %v\n", err)
		exitCode = 1
		return
	}

	// Setup observability (tracing/metrics/logging)
	tp, mp, logger, err := observability.SetupObservability(&cfg.OpenTelemetry, "quiz-backend")
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to initialize observability: %v\n", err)
		exitCode = 1
		return
	}
	defer func() {
		shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer shutdownCancel()
		if tp != nil {
			if err := tp.Shutdown(shutdownCtx); err != nil {
				logger.Warn(ctx, "Error shutting down tracer provider", map[string]interface{}{"error": err.Error(), "provider": "tracer"})
			}
		}
		if mp != nil {
			if err := mp.Shutdown(shutdownCtx); err != nil {
				logger.Warn(ctx, "Error shutting down meter provider", map[string]interface{}{"error": err.Error(), "provider": "meter"})
			}
		}
	}()

	logger.Info(ctx, "Starting quiz backend service", map[string]interface{}{
		"port":     cfg.Server.Port,
		"logLevel": cfg.Server.LogLevel,
	})

	// Initialize dependency injection container and all services
	container := di.NewServiceContainer(cfg, logger)
	if err := container.Initialize(ctx); err != nil {
		logger.Error(ctx, "Failed to initialize services", err, nil)
		exitCode = 1
		return
	}

	// Ensure admin user exists
	if err := container.EnsureAdminUser(ctx); err != nil {
		logger.Error(ctx, "Failed to ensure admin user exists", err, map[string]interface{}{"admin_username": cfg.Server.AdminUsername})
		exitCode = 1
		return
	}

	// Create application instance
	app, err := NewApplication(container)
	if err != nil {
		logger.Error(ctx, "Failed to create application", err, nil)
		exitCode = 1
		return
	}

	// The server gets its own cancellable context so it can be stopped on a
	// signal while ctx stays live for the logging that follows.
	serverCtx, stopServer := context.WithCancel(ctx)
	defer stopServer()
	runResult := make(chan error, 1)
	go func() {
		runResult <- app.Run(serverCtx, cfg.Server.Port)
	}()

	// Wait for shutdown signal or application error
	select {
	case <-shutdownCh:
		logger.Info(ctx, "Received shutdown signal, shutting down gracefully", nil)
		// Restore default signal handling so a second signal can still kill a
		// shutdown that hangs.
		signal.Stop(shutdownCh)
		// Stop the HTTP server and wait for Run to return before the container
		// is shut down underneath any in-flight requests.
		stopServer()
		if err := <-runResult; err != nil {
			logger.Error(ctx, "Error stopping HTTP server", err, nil)
			exitCode = 1
		}
	case err := <-runResult:
		// Run only returns on its own when serving fails; a nil result falls
		// through to the normal shutdown below.
		if err != nil {
			logger.Error(ctx, "Application failed", err, nil)
			exitCode = 1
			return
		}
	}

	// Graceful shutdown of the container
	shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 30*time.Second)
	defer shutdownCancel()
	if err := app.Shutdown(shutdownCtx); err != nil {
		logger.Error(ctx, "Error during application shutdown", err, nil)
		exitCode = 1
		return
	}
	if exitCode == 0 {
		logger.Info(ctx, "Shutdown completed successfully", nil)
	}
}
// Package main provides a utility to set up the test database with initial data.
package main
import (
"context"
"database/sql"
"encoding/json"
"flag"
"fmt"
"os"
"path/filepath"
"strings"
"time"
"quizapp/internal/config"
"quizapp/internal/database"
"quizapp/internal/models"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"go.uber.org/zap/zapcore"
"gopkg.in/yaml.v3"
)
// TestUser represents a user in the test data files (test_users.yaml).
// loadAndCreateUsers turns each entry into a real account. It creates the
// user with email, a "UTC" timezone, language and level, then sets the
// password, the language/level/AI settings and the listed roles.
type TestUser struct {
	Username          string   `yaml:"username"`
	Email             string   `yaml:"email"`
	Password          string   `yaml:"password"` // Set after creation via UpdateUserPassword
	PreferredLanguage string   `yaml:"preferred_language"`
	CurrentLevel      string   `yaml:"current_level"`
	AIProvider        string   `yaml:"ai_provider"` // a non-empty provider also enables AI for the user
	AIModel           string   `yaml:"ai_model"`
	AIAPIKey          string   `yaml:"ai_api_key"`
	Roles             []string `yaml:"roles"` // role names; failed assignments are only logged
}
// TestUsers represents a collection of test users: the top-level document of
// test_users.yaml, whose entries are listed under the "users" key.
type TestUsers struct {
	Users []TestUser `yaml:"users"`
}
// TestQuestions represents a collection of test questions: the top-level
// document of test_questions.yaml. Entries decode straight into
// models.Question. A question's position in this list is the
// "question_index" that the other fixture files use to refer to it.
type TestQuestions struct {
	Questions []models.Question `yaml:"questions"`
}
// TestResponses represents a collection of test user responses: the
// top-level document of test_responses.yaml. It is decoded twice, once for
// the answers and once for the question reports.
//
// QuestionIndex in both sections is a 0-based position in the list of
// created questions. UserAnswer is decoded but not passed on by
// loadAndCreateResponses. CreatedAt, when present, must be RFC 3339;
// otherwise the current time is used.
type TestResponses struct {
	UserResponses []struct {
		Username       string `yaml:"username"`
		QuestionIndex  int    `yaml:"question_index"`
		UserAnswer     string `yaml:"user_answer"`
		IsCorrect      bool   `yaml:"is_correct"`
		ResponseTimeMs int    `yaml:"response_time_ms"`
	} `yaml:"user_responses"`
	QuestionReports []struct {
		Username      string  `yaml:"username"`
		QuestionIndex int     `yaml:"question_index"`
		ReportReason  string  `yaml:"report_reason"`
		CreatedAt     *string `yaml:"created_at"` // optional RFC 3339 timestamp
	} `yaml:"question_reports"`
}
// TestAnalytics represents analytics test data: the top-level document of
// the optional test_analytics.yaml.
//
// Timestamps are RFC 3339 strings. LastCalculatedAt is required, while
// MarkedAsKnownAt is optional and defaults to the current time. QuestionIndex
// is a 0-based position in the list of created questions. Only
// UserQuestionMetadata entries with MarkedAsKnown set produce a database row.
type TestAnalytics struct {
	PriorityScores []struct {
		Username         string  `yaml:"username"`
		QuestionIndex    int     `yaml:"question_index"`
		PriorityScore    float64 `yaml:"priority_score"`
		LastCalculatedAt string  `yaml:"last_calculated_at"`
	} `yaml:"priority_scores"`
	// No daily_goal field: the loader substitutes the learning service default.
	LearningPreferences []struct {
		Username             string  `yaml:"username"`
		FocusOnWeakAreas     bool    `yaml:"focus_on_weak_areas"`
		FreshQuestionRatio   float64 `yaml:"fresh_question_ratio"`
		WeakAreaBoost        float64 `yaml:"weak_area_boost"`
		KnownQuestionPenalty float64 `yaml:"known_question_penalty"`
		ReviewIntervalDays   int     `yaml:"review_interval_days"`
		DailyReminderEnabled bool    `yaml:"daily_reminder_enabled"`
	} `yaml:"learning_preferences"`
	// Upserted on (user, topic, language, level).
	PerformanceMetrics []struct {
		Username              string  `yaml:"username"`
		Topic                 string  `yaml:"topic"`
		Language              string  `yaml:"language"`
		Level                 string  `yaml:"level"`
		TotalAttempts         int     `yaml:"total_attempts"`
		CorrectAttempts       int     `yaml:"correct_attempts"`
		AverageResponseTimeMs float64 `yaml:"average_response_time_ms"`
	} `yaml:"performance_metrics"`
	UserQuestionMetadata []struct {
		Username        string  `yaml:"username"`
		QuestionIndex   int     `yaml:"question_index"`
		MarkedAsKnown   bool    `yaml:"marked_as_known"`
		MarkedAsKnownAt *string `yaml:"marked_as_known_at"`
	} `yaml:"user_question_metadata"`
}
// TestDailyAssignments represents the structure for daily question assignments in test data
// (the optional test_daily_assignments.yaml).
//
// Date uses the "2006-01-02" layout. Despite their names, QuestionIDs and
// CompletedQuestions are NOT database IDs. They are 1-based positions in
// the list of created questions, and the loader converts them to 0-based
// indexes. CompletedQuestions uses the same numbering as QuestionIDs.
type TestDailyAssignments struct {
	DailyAssignments []struct {
		Username           string `yaml:"username"`
		Date               string `yaml:"date"`
		QuestionIDs        []int  `yaml:"question_ids"`
		CompletedQuestions []int  `yaml:"completed_questions"`
	} `yaml:"daily_assignments"`
}
// resetTestDatabase drops and recreates the database named testDB.
//
// A database cannot be dropped over a connection to itself. So it derives
// an admin connection string from databaseURL by replacing the "/<testDB>"
// path segment (with or without a query string) with "/postgres". It then
// terminates other sessions on testDB and runs DROP DATABASE ... WITH (FORCE)
// followed by CREATE DATABASE. FORCE requires PostgreSQL 13 or newer.
//
// NOTE(review): if databaseURL does not contain "/<testDB>", no substitution
// happens and the admin connection goes to whatever database databaseURL
// names.
//
// The datname filter is a value and is bound as $1. DROP/CREATE take an
// identifier, which cannot be bound as a parameter, so testDB is formatted
// into those statements. It must therefore come from a trusted source.
func resetTestDatabase(databaseURL, testDB string, logger *observability.Logger) error {
	ctx := context.Background()

	// Create admin connection string by replacing the database name with 'postgres'
	adminConnStr := strings.Replace(databaseURL, "/"+testDB+"?", "/postgres?", 1)
	if !strings.Contains(adminConnStr, "/postgres?") {
		// Handle case where there's no query string
		adminConnStr = strings.Replace(databaseURL, "/"+testDB, "/postgres", 1)
	}
	logger.Info(ctx, "Connecting to admin database", map[string]interface{}{"connection_string": adminConnStr})
	adminDB, err := sql.Open("postgres", adminConnStr)
	if err != nil {
		return contextutils.WrapErrorf(contextutils.ErrDatabaseConnection, "failed to connect to postgres database for drop/create: %v", err)
	}
	defer func() {
		if err := adminDB.Close(); err != nil {
			logger.Warn(ctx, "Warning: failed to close adminDB", map[string]interface{}{"error": err.Error()})
		}
	}()

	// Best effort: DROP ... WITH (FORCE) below also evicts sessions, so a
	// failure here is only logged.
	logger.Info(ctx, "Terminating connections to test DB", map[string]interface{}{"database": testDB})
	if _, err := adminDB.ExecContext(ctx, `
SELECT pg_terminate_backend(pid)
FROM pg_stat_activity
WHERE datname = $1 AND pid <> pg_backend_pid()
`, testDB); err != nil {
		logger.Warn(ctx, "Warning: failed to terminate connections", map[string]interface{}{"error": err.Error()})
	}

	logger.Info(ctx, "Dropping test database", map[string]interface{}{"database": testDB})
	if _, err := adminDB.ExecContext(ctx, fmt.Sprintf("DROP DATABASE IF EXISTS %s WITH (FORCE);", testDB)); err != nil {
		return contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to drop test database: %v", err)
	}
	logger.Info(ctx, "Successfully dropped test database", map[string]interface{}{"database": testDB})

	logger.Info(ctx, "Creating test database", map[string]interface{}{"database": testDB})
	if _, err := adminDB.ExecContext(ctx, fmt.Sprintf("CREATE DATABASE %s;", testDB)); err != nil {
		return contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to create test database: %v", err)
	}
	logger.Info(ctx, "Successfully created test database", map[string]interface{}{"database": testDB})
	logger.Info(ctx, "Test database reset complete")
	return nil
}
// main builds the end-to-end test database from scratch.
//
// It resets the database, applies ../schema.sql, ensures an "admin" user,
// loads the YAML fixtures under ./data and writes the user and role data the
// E2E tests read. --verbose raises the log level from Warn to Info. The
// connection string comes from DATABASE_URL, or else a local default for
// quiz_test_db on port 5433.
//
// Failures record a non-zero exit code and return instead of calling
// os.Exit directly. os.Exit skips deferred calls, which would leave the
// database handle open and the tracer/meter providers unflushed. The exit
// code is applied by the first-registered defer, which runs after every
// other deferred cleanup.
func main() {
	exitCode := 0
	defer func() {
		if exitCode != 0 {
			os.Exit(exitCode)
		}
	}()

	ctx := context.Background()

	// CLI flags
	verbose := flag.Bool("verbose", false, "enable verbose logging")
	flag.Parse()

	// Load configuration first
	cfg, err := config.NewConfig()
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to load config: %v\n", err)
		exitCode = 1
		return
	}

	// Setup observability (tracing/metrics). Suppress logger creation here to avoid startup noise.
	originalLogging := cfg.OpenTelemetry.EnableLogging
	cfg.OpenTelemetry.EnableLogging = false
	tp, mp, _, err := observability.SetupObservability(&cfg.OpenTelemetry, "setup-test-db")
	if err != nil {
		fmt.Fprintf(os.Stderr, "Failed to initialize observability: %v\n", err)
		exitCode = 1
		return
	}

	// Create logger with level based on --verbose flag
	logLevel := zapcore.WarnLevel
	if *verbose {
		logLevel = zapcore.InfoLevel
	}
	// Restore config flag for logger construction (to allow OTLP exporter if enabled)
	cfg.OpenTelemetry.EnableLogging = originalLogging
	logger := observability.NewLoggerWithLevel(&cfg.OpenTelemetry, logLevel)
	defer func() {
		// Bound the flush so an unreachable collector cannot hang the tool.
		shutdownCtx, shutdownCancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer shutdownCancel()
		if tp != nil {
			if err := tp.Shutdown(shutdownCtx); err != nil {
				logger.Warn(ctx, "Error shutting down tracer provider", map[string]interface{}{"error": err.Error()})
			}
		}
		if mp != nil {
			if err := mp.Shutdown(shutdownCtx); err != nil {
				logger.Warn(ctx, "Error shutting down meter provider", map[string]interface{}{"error": err.Error()})
			}
		}
	}()

	// fail logs msg with err and marks the run as failed; callers return right after.
	fail := func(msg string, err error) {
		logger.Error(ctx, msg, err)
		exitCode = 1
	}

	// Get DB connection info from env or use defaults
	dbUser := "quiz_user"
	dbPassword := "quiz_password"
	dbHost := "localhost"
	dbPort := "5433"
	testDB := "quiz_test_db"

	// Allow override from DATABASE_URL
	databaseURL := os.Getenv("DATABASE_URL")
	if databaseURL == "" {
		databaseURL = fmt.Sprintf("postgres://%s:%s@%s:%s/%s?sslmode=disable", dbUser, dbPassword, dbHost, dbPort, testDB)
	}
	logger.Info(ctx, "DATABASE_URL from environment", map[string]interface{}{"database_url": os.Getenv("DATABASE_URL")})
	logger.Info(ctx, "Using database URL", map[string]interface{}{"database_url": databaseURL})

	// --- Drop and recreate the test database ---
	if err := resetTestDatabase(databaseURL, testDB, logger); err != nil {
		fail("Failed to reset test database", err)
		return
	}

	// Now connect to the new test database
	logger.Info(ctx, "Connecting to database", map[string]interface{}{"database_url": databaseURL})
	dbManager := database.NewManager(logger)
	db, err := dbManager.InitDB(databaseURL)
	if err != nil {
		fail("Failed to initialize database", err)
		return
	}
	defer func() {
		if err := db.Close(); err != nil {
			logger.Warn(ctx, "Warning: failed to close database", map[string]interface{}{"error": err.Error()})
		}
	}()

	// Get the root directory (backend is the working directory)
	rootDir, err := os.Getwd()
	if err != nil {
		fail("Failed to get working directory", err)
		return
	}

	// Apply schema from schema.sql
	schemaPath := filepath.Join(rootDir, "..", "schema.sql")
	if err := applySchema(db, schemaPath, rootDir, logger); err != nil {
		fail("Failed to apply schema", err)
		return
	}

	// Initialize services
	userService := services.NewUserServiceWithLogger(db, cfg, logger)
	learningService := services.NewLearningServiceWithLogger(db, cfg, logger)
	questionService := services.NewQuestionServiceWithLogger(db, learningService, cfg, logger)

	// Fixed admin credentials for the test database (independent of cfg.Server).
	if err := userService.EnsureAdminUserExists(ctx, "admin", "password"); err != nil {
		fail("Failed to ensure admin user exists", err)
		return
	}

	// Load and insert test data
	users, err := setupTestData(ctx, rootDir, userService, questionService, learningService, db, logger)
	if err != nil {
		fail("Failed to setup test data", err)
		return
	}

	// Output user data to JSON file for E2E tests
	if err := outputUserDataForTests(users, rootDir, logger); err != nil {
		fail("Failed to output user data for tests", err)
		return
	}

	// Output roles data to JSON file for E2E tests
	if err := outputRolesDataForTests(db, rootDir, logger); err != nil {
		fail("Failed to output roles data for tests", err)
		return
	}

	logger.Info(ctx, "Test database created successfully")
}
// applySchema prepares db for the fixtures. It drops the core tables and
// their id sequences (CASCADE) so nothing stale survives, then executes the
// whole SQL file at schemaPath in a single Exec.
//
// The third parameter (the backend root directory at the call site) is
// unused and kept only for signature compatibility. Every failure is wrapped
// with contextutils.ErrDatabaseQuery.
func applySchema(db *sql.DB, schemaPath, _ string, logger *observability.Logger) error {
	bg := context.Background()

	logger.Info(bg, "Dropping existing tables and sequences")
	const dropSQL = `
-- Drop tables in reverse dependency order
DROP TABLE IF EXISTS performance_metrics CASCADE;
DROP TABLE IF EXISTS user_responses CASCADE;
DROP TABLE IF EXISTS questions CASCADE;
DROP TABLE IF EXISTS users CASCADE;
-- Drop any remaining sequences (in case they weren't cleaned up)
DROP SEQUENCE IF EXISTS users_id_seq CASCADE;
DROP SEQUENCE IF EXISTS questions_id_seq CASCADE;
DROP SEQUENCE IF EXISTS user_responses_id_seq CASCADE;
DROP SEQUENCE IF EXISTS performance_metrics_id_seq CASCADE;
`
	if _, dropErr := db.Exec(dropSQL); dropErr != nil {
		return contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to drop existing tables: %w", dropErr)
	}

	logger.Info(bg, "Applying schema")
	schemaBytes, readErr := os.ReadFile(schemaPath)
	if readErr != nil {
		return contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to read schema file: %w", readErr)
	}
	if _, execErr := db.Exec(string(schemaBytes)); execErr != nil {
		return contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to execute schema: %w", execErr)
	}

	// schema.sql already defines the priority-system tables, so no follow-up
	// migration runs here.
	logger.Info(bg, "Priority system tables already included in main schema")
	return nil
}
// setupTestData loads every fixture file under <rootDir>/data and writes it
// to the database in dependency order. The order is: users, questions, user
// responses, question reports (also read from test_responses.yaml),
// analytics and daily assignments. Later steps refer to the users and
// questions created by earlier ones. The first failing step aborts the run,
// and its error is wrapped with contextutils.ErrDatabaseQuery. On success the
// created users are returned, keyed by username.
func setupTestData(ctx context.Context, rootDir string, userService *services.UserService, questionService *services.QuestionService, learningService *services.LearningService, db *sql.DB, logger *observability.Logger) (map[string]*models.User, error) {
	fixture := func(name string) string {
		return filepath.Join(rootDir, "data", name)
	}

	users, err := loadAndCreateUsers(ctx, fixture("test_users.yaml"), userService, logger)
	if err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to setup users: %w", err)
	}

	questions, err := loadAndCreateQuestions(ctx, fixture("test_questions.yaml"), questionService, users, logger)
	if err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to setup questions: %w", err)
	}

	// The remaining steps only report errors, so they run in order from a table.
	responsesFile := fixture("test_responses.yaml")
	steps := []struct {
		format string
		run    func() error
	}{
		{"failed to setup responses: %w", func() error {
			return loadAndCreateResponses(ctx, responsesFile, users, questions, learningService, logger)
		}},
		{"failed to setup question reports: %w", func() error {
			return loadAndCreateQuestionReports(ctx, responsesFile, users, questions, db, logger)
		}},
		{"failed to setup analytics: %w", func() error {
			return loadAndCreateAnalytics(ctx, fixture("test_analytics.yaml"), users, questions, learningService, db, logger)
		}},
		{"failed to setup daily assignments: %w", func() error {
			return loadAndCreateDailyAssignments(ctx, fixture("test_daily_assignments.yaml"), users, questions, db, logger)
		}},
	}
	for _, step := range steps {
		if stepErr := step.run(); stepErr != nil {
			return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, step.format, stepErr)
		}
	}

	return users, nil
}
// loadAndCreateUsers creates one account per entry of the users fixture at
// filePath and returns the accounts keyed by username.
//
// For every entry, in file order, it does the following:
//   - create the user with email, a "UTC" timezone, language and level;
//   - set the password;
//   - store the language, level and AI settings (AI is enabled exactly when a
//     provider is given);
//   - assign each listed role.
//
// Read, parse, create, password and settings errors abort the whole load. A
// role that cannot be assigned is only logged as a warning.
func loadAndCreateUsers(ctx context.Context, filePath string, userService *services.UserService, logger *observability.Logger) (result0 map[string]*models.User, err error) {
	raw, err := os.ReadFile(filePath)
	if err != nil {
		return nil, err
	}
	var fixture TestUsers
	if err = yaml.Unmarshal(raw, &fixture); err != nil {
		return nil, err
	}

	created := make(map[string]*models.User, len(fixture.Users))
	for _, spec := range fixture.Users {
		account, createErr := userService.CreateUserWithEmailAndTimezone(
			ctx,
			spec.Username,
			spec.Email,
			"UTC", // Default timezone for test users
			spec.PreferredLanguage,
			spec.CurrentLevel,
		)
		if createErr != nil {
			return nil, contextutils.WrapErrorf(createErr, "failed to create user %s", spec.Username)
		}

		// The creation call above does not take a password.
		if pwErr := userService.UpdateUserPassword(ctx, account.ID, spec.Password); pwErr != nil {
			return nil, contextutils.WrapErrorf(pwErr, "failed to set password for user %s", spec.Username)
		}

		if settingsErr := userService.UpdateUserSettings(ctx, account.ID, &models.UserSettings{
			Language:   spec.PreferredLanguage,
			Level:      spec.CurrentLevel,
			AIProvider: spec.AIProvider,
			AIModel:    spec.AIModel,
			AIAPIKey:   spec.AIAPIKey,
			AIEnabled:  spec.AIProvider != "",
		}); settingsErr != nil {
			return nil, contextutils.WrapErrorf(settingsErr, "failed to update settings for user %s", spec.Username)
		}

		for _, role := range spec.Roles {
			if roleErr := userService.AssignRoleByName(ctx, account.ID, role); roleErr != nil {
				logger.Warn(ctx, "Failed to assign role to user", map[string]interface{}{
					"username": spec.Username,
					"role":     role,
					"error":    roleErr.Error(),
				})
				continue
			}
			logger.Info(ctx, "Assigned role to user", map[string]interface{}{
				"username": spec.Username,
				"role":     role,
				"user_id":  account.ID,
			})
		}

		created[spec.Username] = account
	}
	return created, nil
}
// loadAndCreateQuestions reads the question fixtures at filePath, saves each
// one through questionService and assigns it to users.
//
// A question's explicit users list (usernames resolved through users) takes
// precedence. A question without one is assigned to a single user, chosen
// round-robin by its position in the file. The round-robin walks users in
// ascending ID order so the same fixtures always produce the same
// assignments. Ranging over the users map directly would follow Go's
// randomized map iteration order.
//
// It returns pointers to the saved questions in file order, so a fixture's
// position in the file is its index in the result. Each pointer refers to a
// distinct element of the decoded slice. This does not rely on Go 1.22
// per-iteration loop variables.
func loadAndCreateQuestions(ctx context.Context, filePath string, questionService *services.QuestionService, users map[string]*models.User, _ *observability.Logger) (result0 []*models.Question, err error) {
	data, err := os.ReadFile(filePath)
	if err != nil {
		return nil, err
	}
	var testQuestions TestQuestions
	if err := yaml.Unmarshal(data, &testQuestions); err != nil {
		return nil, err
	}

	// Round-robin candidates for questions without explicit users, built once
	// and sorted by ID (insertion sort; fixtures have only a handful of users).
	fallbackIDs := make([]int, 0, len(users))
	for _, user := range users {
		fallbackIDs = append(fallbackIDs, user.ID)
	}
	for j := 1; j < len(fallbackIDs); j++ {
		for k := j; k > 0 && fallbackIDs[k] < fallbackIDs[k-1]; k-- {
			fallbackIDs[k], fallbackIDs[k-1] = fallbackIDs[k-1], fallbackIDs[k]
		}
	}

	questions := make([]*models.Question, 0, len(testQuestions.Questions))
	for i := range testQuestions.Questions {
		question := &testQuestions.Questions[i]
		// Set the created time since it's not in YAML
		question.CreatedAt = time.Now()

		var assignedUserIDs []int
		if len(question.Users) == 0 {
			if len(fallbackIDs) == 0 {
				return nil, contextutils.ErrorWithContextf("no users available to assign questions to")
			}
			assignedUserIDs = []int{fallbackIDs[i%len(fallbackIDs)]}
		} else {
			for _, username := range question.Users {
				user, exists := users[username]
				if !exists {
					return nil, contextutils.ErrorWithContextf("user not found: %s", username)
				}
				assignedUserIDs = append(assignedUserIDs, user.ID)
			}
		}

		if err := questionService.SaveQuestion(ctx, question); err != nil {
			return nil, contextutils.WrapErrorf(err, "failed to save question %d", i)
		}
		for _, userID := range assignedUserIDs {
			if err := questionService.AssignQuestionToUser(ctx, question.ID, userID); err != nil {
				return nil, contextutils.WrapErrorf(err, "failed to assign question %d to user %d", question.ID, userID)
			}
		}
		questions = append(questions, question)
	}
	return questions, nil
}
// loadAndCreateResponses replays the user_responses section of the fixture
// file at filePath through learningService.RecordAnswerWithPriority, so that
// priority scores are computed as for real answers.
//
// Each entry names a user (resolved through users) and a 0-based
// question_index into questions. user_answer is not used: every response is
// recorded with answer index 0. The function returns on the first unknown
// user, the first out-of-range index (negative indexes included, which would
// otherwise panic) or the first recording failure.
func loadAndCreateResponses(ctx context.Context, filePath string, users map[string]*models.User, questions []*models.Question, learningService *services.LearningService, _ *observability.Logger) error {
	data, err := os.ReadFile(filePath)
	if err != nil {
		return err
	}
	var testResponses TestResponses
	if err := yaml.Unmarshal(data, &testResponses); err != nil {
		return err
	}
	for i, responseData := range testResponses.UserResponses {
		user, exists := users[responseData.Username]
		if !exists {
			return contextutils.ErrorWithContextf("user not found: %s", responseData.Username)
		}
		if responseData.QuestionIndex < 0 || responseData.QuestionIndex >= len(questions) {
			return contextutils.ErrorWithContextf("question index out of range: %d", responseData.QuestionIndex)
		}
		question := questions[responseData.QuestionIndex]
		// Use RecordAnswerWithPriority to ensure priority scores are calculated
		if err := learningService.RecordAnswerWithPriority(
			ctx,
			user.ID,
			question.ID,
			0, // Use index 0 for test data
			responseData.IsCorrect,
			responseData.ResponseTimeMs,
		); err != nil {
			return contextutils.WrapErrorf(err, "failed to record response %d", i)
		}
	}
	return nil
}
// loadAndCreateQuestionReports inserts the question_reports section of the
// fixture file at filePath directly into the question_reports table.
//
// Each entry names a user (resolved through users) and a 0-based
// question_index into questions. Negative indexes are rejected rather than
// panicking. created_at is an optional RFC 3339 timestamp that defaults to
// the current time. ON CONFLICT ... DO NOTHING keeps the first report per
// (question, reporting user), so duplicate entries are ignored. The first
// invalid entry or failed insert aborts the load.
func loadAndCreateQuestionReports(ctx context.Context, filePath string, users map[string]*models.User, questions []*models.Question, db *sql.DB, _ *observability.Logger) error {
	data, err := os.ReadFile(filePath)
	if err != nil {
		return contextutils.WrapError(err, "failed to read responses file")
	}
	var testResponses TestResponses
	if err := yaml.Unmarshal(data, &testResponses); err != nil {
		return contextutils.WrapError(err, "failed to parse responses data")
	}

	for i, reportData := range testResponses.QuestionReports {
		user, exists := users[reportData.Username]
		if !exists {
			return contextutils.ErrorWithContextf("user not found for question report: %s", reportData.Username)
		}
		if reportData.QuestionIndex < 0 || reportData.QuestionIndex >= len(questions) {
			return contextutils.ErrorWithContextf("question index out of range for question report: %d", reportData.QuestionIndex)
		}
		question := questions[reportData.QuestionIndex]

		createdAt := time.Now()
		if reportData.CreatedAt != nil {
			parsed, err := time.Parse(time.RFC3339, *reportData.CreatedAt)
			if err != nil {
				return contextutils.ErrorWithContextf("invalid timestamp format for question report: %s", *reportData.CreatedAt)
			}
			createdAt = parsed
		}

		if _, err := db.ExecContext(ctx, `
INSERT INTO question_reports (question_id, reported_by_user_id, report_reason, created_at)
VALUES ($1, $2, $3, $4)
ON CONFLICT (question_id, reported_by_user_id) DO NOTHING
`, question.ID, user.ID, reportData.ReportReason, createdAt); err != nil {
			return contextutils.WrapErrorf(err, "failed to insert question report %d", i)
		}
	}
	return nil
}
// loadAndCreateAnalytics loads the optional analytics fixture at filePath.
// It processes four sections:
//   - priority scores: upserted on (user_id, question_id);
//   - learning preferences: saved via learningService;
//   - performance metrics: upserted on (user_id, topic, language, level);
//   - question metadata: only marked-as-known entries produce a row.
//
// A missing file is logged and skipped. Any other read error (permissions,
// I/O) is returned, because it must not be mistaken for "file absent".
// question_index values are 0-based indexes into questions, and negative
// values are rejected rather than panicking. Timestamps are RFC 3339. The
// first invalid entry or failed write aborts the load.
func loadAndCreateAnalytics(ctx context.Context, filePath string, users map[string]*models.User, questions []*models.Question, learningService *services.LearningService, db *sql.DB, logger *observability.Logger) error {
	data, err := os.ReadFile(filePath)
	if err != nil {
		if os.IsNotExist(err) {
			// Analytics file is optional, so just return if it doesn't exist
			logger.Warn(ctx, "Analytics file not found", map[string]interface{}{"file_path": filePath})
			return nil
		}
		return contextutils.WrapError(err, "failed to read analytics file")
	}
	var testAnalytics TestAnalytics
	if err := yaml.Unmarshal(data, &testAnalytics); err != nil {
		return contextutils.WrapError(err, "failed to parse analytics data")
	}

	// Load priority scores
	for _, priorityData := range testAnalytics.PriorityScores {
		user, exists := users[priorityData.Username]
		if !exists {
			return contextutils.ErrorWithContextf("user not found for priority score: %s", priorityData.Username)
		}
		if priorityData.QuestionIndex < 0 || priorityData.QuestionIndex >= len(questions) {
			return contextutils.ErrorWithContextf("question index out of range for priority score: %d", priorityData.QuestionIndex)
		}
		question := questions[priorityData.QuestionIndex]
		lastCalculatedAt, err := time.Parse(time.RFC3339, priorityData.LastCalculatedAt)
		if err != nil {
			return contextutils.ErrorWithContextf("invalid timestamp format for priority score: %s", priorityData.LastCalculatedAt)
		}
		if _, err := db.ExecContext(ctx, `
INSERT INTO question_priority_scores (user_id, question_id, priority_score, last_calculated_at, created_at, updated_at)
VALUES ($1, $2, $3, $4, NOW(), NOW())
ON CONFLICT (user_id, question_id) DO UPDATE SET
priority_score = EXCLUDED.priority_score,
last_calculated_at = EXCLUDED.last_calculated_at,
updated_at = NOW()
`, user.ID, question.ID, priorityData.PriorityScore, lastCalculatedAt); err != nil {
			return contextutils.WrapError(err, "failed to insert priority score")
		}
	}

	// Load learning preferences
	for _, prefData := range testAnalytics.LearningPreferences {
		user, exists := users[prefData.Username]
		if !exists {
			return contextutils.ErrorWithContextf("user not found for learning preferences: %s", prefData.Username)
		}
		// The fixture format has no daily_goal field and the schema enforces
		// daily_goal > 0, so every user gets the service's default goal.
		dailyGoal := learningService.GetDefaultLearningPreferences().DailyGoal
		prefs := &models.UserLearningPreferences{
			UserID:               user.ID,
			FocusOnWeakAreas:     prefData.FocusOnWeakAreas,
			FreshQuestionRatio:   prefData.FreshQuestionRatio,
			WeakAreaBoost:        prefData.WeakAreaBoost,
			KnownQuestionPenalty: prefData.KnownQuestionPenalty,
			ReviewIntervalDays:   prefData.ReviewIntervalDays,
			DailyReminderEnabled: prefData.DailyReminderEnabled,
			DailyGoal:            dailyGoal,
		}
		if _, err := learningService.UpdateUserLearningPreferences(ctx, user.ID, prefs); err != nil {
			return contextutils.WrapErrorf(err, "failed to update learning preferences for user %s", prefData.Username)
		}
	}

	// Load performance metrics
	for _, metricData := range testAnalytics.PerformanceMetrics {
		user, exists := users[metricData.Username]
		if !exists {
			return contextutils.ErrorWithContextf("user not found for performance metrics: %s", metricData.Username)
		}
		if _, err := db.ExecContext(ctx, `
INSERT INTO performance_metrics (user_id, topic, language, level, total_attempts, correct_attempts, average_response_time_ms, last_updated)
VALUES ($1, $2, $3, $4, $5, $6, $7, NOW())
ON CONFLICT (user_id, topic, language, level) DO UPDATE SET
total_attempts = EXCLUDED.total_attempts,
correct_attempts = EXCLUDED.correct_attempts,
average_response_time_ms = EXCLUDED.average_response_time_ms,
last_updated = NOW()
`, user.ID, metricData.Topic, metricData.Language, metricData.Level,
			metricData.TotalAttempts, metricData.CorrectAttempts, metricData.AverageResponseTimeMs); err != nil {
			return contextutils.WrapError(err, "failed to insert performance metric")
		}
	}

	// Load user question metadata (marked as known)
	for _, metadata := range testAnalytics.UserQuestionMetadata {
		user, exists := users[metadata.Username]
		if !exists {
			return contextutils.ErrorWithContextf("user not found for question metadata: %s", metadata.Username)
		}
		if metadata.QuestionIndex < 0 || metadata.QuestionIndex >= len(questions) {
			return contextutils.ErrorWithContextf("question index out of range for metadata: %d", metadata.QuestionIndex)
		}
		question := questions[metadata.QuestionIndex]
		// Entries are validated above even when they produce no row.
		if !metadata.MarkedAsKnown {
			continue
		}
		markedAt := time.Now()
		if metadata.MarkedAsKnownAt != nil {
			parsed, err := time.Parse(time.RFC3339, *metadata.MarkedAsKnownAt)
			if err != nil {
				return contextutils.ErrorWithContextf("invalid timestamp format for marked as known: %s", *metadata.MarkedAsKnownAt)
			}
			markedAt = parsed
		}
		if _, err := db.ExecContext(ctx, `
INSERT INTO user_question_metadata (user_id, question_id, marked_as_known, marked_as_known_at, created_at, updated_at)
VALUES ($1, $2, $3, $4, NOW(), NOW())
ON CONFLICT (user_id, question_id) DO UPDATE SET
marked_as_known = EXCLUDED.marked_as_known,
marked_as_known_at = EXCLUDED.marked_as_known_at,
updated_at = NOW()
`, user.ID, question.ID, metadata.MarkedAsKnown, markedAt); err != nil {
			return contextutils.WrapError(err, "failed to insert question metadata")
		}
	}
	return nil
}
// loadAndCreateDailyAssignments loads the optional daily-assignment fixture at
// filePath into daily_question_assignments.
//
// A missing file is logged and skipped. Any other read error is returned
// rather than silently treated as "absent". Entries with an unknown user or a
// date not in "2006-01-02" form are skipped with a warning. question_ids are
// 1-based positions in questions, not database IDs, and out-of-range values
// are skipped. Each (user, question, date) row is deleted before being
// inserted, so the load does not depend on a unique constraint. Questions
// listed in completed_questions are marked completed with completed_at set to
// now. The logged count is the number of rows actually inserted.
func loadAndCreateDailyAssignments(ctx context.Context, filePath string, users map[string]*models.User, questions []*models.Question, db *sql.DB, logger *observability.Logger) error {
	data, err := os.ReadFile(filePath)
	if err != nil {
		if os.IsNotExist(err) {
			// File doesn't exist, skip daily assignments
			logger.Info(ctx, "Daily assignments file not found, skipping", map[string]interface{}{
				"file_path": filePath,
			})
			return nil
		}
		return contextutils.WrapError(err, "failed to read daily assignments file")
	}
	var testDailyAssignments TestDailyAssignments
	if err := yaml.Unmarshal(data, &testDailyAssignments); err != nil {
		return contextutils.WrapError(err, "failed to parse daily assignments data")
	}

	for _, assignmentData := range testDailyAssignments.DailyAssignments {
		user, exists := users[assignmentData.Username]
		if !exists {
			logger.Warn(ctx, "User not found for daily assignment", map[string]interface{}{
				"username": assignmentData.Username,
			})
			continue
		}
		date, err := time.Parse("2006-01-02", assignmentData.Date)
		if err != nil {
			logger.Warn(ctx, "Invalid date format for daily assignment", map[string]interface{}{
				"username": assignmentData.Username,
				"date":     assignmentData.Date,
			})
			continue
		}

		// Create a map of completed questions for quick lookup
		completedQuestions := make(map[int]bool, len(assignmentData.CompletedQuestions))
		for _, qID := range assignmentData.CompletedQuestions {
			completedQuestions[qID] = true
		}

		created := 0
		for _, questionID := range assignmentData.QuestionIDs {
			if questionID <= 0 || questionID > len(questions) {
				logger.Warn(ctx, "Question ID out of range for daily assignment", map[string]interface{}{
					"username":    assignmentData.Username,
					"date":        assignmentData.Date,
					"question_id": questionID,
				})
				continue
			}
			question := questions[questionID-1] // Convert to 0-based index

			// Remove any existing row for this (user_id, question_id, assignment_date)
			// first instead of relying on ON CONFLICT, which needs a constraint that
			// is not present in every test DB state.
			deleteQuery := `DELETE FROM daily_question_assignments WHERE user_id = $1 AND question_id = $2 AND assignment_date = $3`
			if _, err := db.ExecContext(ctx, deleteQuery, user.ID, question.ID, date); err != nil {
				logger.Error(ctx, "Failed to delete existing daily assignment", err, map[string]interface{}{
					"username":    assignmentData.Username,
					"date":        assignmentData.Date,
					"question_id": questionID,
				})
				return contextutils.WrapErrorf(err, "failed to delete existing daily assignment for user %s, question %d", assignmentData.Username, questionID)
			}

			query := `
INSERT INTO daily_question_assignments (user_id, question_id, assignment_date, is_completed, completed_at)
VALUES ($1, $2, $3, $4, $5)
`
			isCompleted := completedQuestions[questionID]
			var completedAt *time.Time
			if isCompleted {
				now := time.Now()
				completedAt = &now
			}
			if _, err := db.ExecContext(ctx, query, user.ID, question.ID, date, isCompleted, completedAt); err != nil {
				logger.Error(ctx, "Failed to create daily assignment", err, map[string]interface{}{
					"username":    assignmentData.Username,
					"date":        assignmentData.Date,
					"question_id": questionID,
				})
				return contextutils.WrapErrorf(err, "failed to create daily assignment for user %s, question %d", assignmentData.Username, questionID)
			}
			created++
		}

		logger.Info(ctx, "Created daily assignments", map[string]interface{}{
			"username": assignmentData.Username,
			"date":     assignmentData.Date,
			"count":    created,
		})
	}
	return nil
}
// outputUserDataForTests writes the created users to
// <rootDir>/../frontend/tests/test-users.json for the E2E tests to read.
// The JSON object is keyed by the users map key, and each value carries the
// user's ID, Username and Email.String. The output directory is created
// (0755) if needed and the file is written with mode 0644. Directory, encoding
// and write failures are returned wrapped.
func outputUserDataForTests(users map[string]*models.User, rootDir string, logger *observability.Logger) error {
	// userEntry is the simplified per-user shape consumed by the E2E tests.
	type userEntry struct {
		ID       int    `json:"id"`
		Username string `json:"username"`
		Email    string `json:"email"`
	}

	entries := make(map[string]userEntry, len(users))
	for name, u := range users {
		entries[name] = userEntry{ID: u.ID, Username: u.Username, Email: u.Email.String}
	}

	dest := filepath.Join(rootDir, "..", "frontend", "tests", "test-users.json")
	destDir := filepath.Dir(dest)
	if mkErr := os.MkdirAll(destDir, 0o755); mkErr != nil {
		return contextutils.WrapErrorf(mkErr, "failed to create output directory: %s", destDir)
	}

	payload, encErr := json.MarshalIndent(entries, "", " ")
	if encErr != nil {
		return contextutils.WrapErrorf(encErr, "failed to marshal user data to JSON")
	}
	if writeErr := os.WriteFile(dest, payload, 0o644); writeErr != nil {
		return contextutils.WrapErrorf(writeErr, "failed to write user data to file: %s", dest)
	}

	logger.Info(context.Background(), "Output user data for E2E tests", map[string]interface{}{
		"file_path":  dest,
		"user_count": len(entries),
	})
	return nil
}
// outputRolesDataForTests reads every row of the roles table (ordered by id)
// and writes them to <rootDir>/../frontend/tests/test-roles.json for the E2E
// tests to read. The JSON object is keyed by role name, and each value carries
// the role's ID, Name and Description. The output directory is created (0755)
// if needed and the file is written with mode 0644. Query, scan, iteration,
// encoding and write failures are returned wrapped. A failure to close the
// result set is only logged as a warning.
func outputRolesDataForTests(db *sql.DB, rootDir string, logger *observability.Logger) error {
	const rolesQuery = `
SELECT id, name, description, created_at, updated_at
FROM roles
ORDER BY id
`
	roleRows, queryErr := db.Query(rolesQuery)
	if queryErr != nil {
		return contextutils.WrapErrorf(queryErr, "failed to query roles from database")
	}
	defer func() {
		if closeErr := roleRows.Close(); closeErr != nil {
			logger.Warn(context.Background(), "Warning: failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	// roleEntry is the simplified per-role shape consumed by the E2E tests.
	type roleEntry struct {
		ID          int    `json:"id"`
		Name        string `json:"name"`
		Description string `json:"description"`
	}

	byName := make(map[string]roleEntry)
	for roleRows.Next() {
		var r models.Role
		if scanErr := roleRows.Scan(&r.ID, &r.Name, &r.Description, &r.CreatedAt, &r.UpdatedAt); scanErr != nil {
			return contextutils.WrapErrorf(scanErr, "failed to scan role data")
		}
		byName[r.Name] = roleEntry{ID: r.ID, Name: r.Name, Description: r.Description}
	}
	if iterErr := roleRows.Err(); iterErr != nil {
		return contextutils.WrapErrorf(iterErr, "error iterating over roles")
	}

	dest := filepath.Join(rootDir, "..", "frontend", "tests", "test-roles.json")
	destDir := filepath.Dir(dest)
	if mkErr := os.MkdirAll(destDir, 0o755); mkErr != nil {
		return contextutils.WrapErrorf(mkErr, "failed to create output directory: %s", destDir)
	}

	payload, encErr := json.MarshalIndent(byName, "", " ")
	if encErr != nil {
		return contextutils.WrapErrorf(encErr, "failed to marshal roles data to JSON")
	}
	if writeErr := os.WriteFile(dest, payload, 0o644); writeErr != nil {
		return contextutils.WrapErrorf(writeErr, "failed to write roles data to file: %s", dest)
	}

	logger.Info(context.Background(), "Output roles data for E2E tests", map[string]interface{}{
		"file_path":   dest,
		"roles_count": len(byName),
	})
	return nil
}
// Package main provides the entry point for the Quiz Application worker service.
package main
import (
"context"
"io/fs"
"net/http"
"os"
"os/signal"
"syscall"
"time"
"quizapp/internal/config"
"quizapp/internal/database"
"quizapp/internal/handlers"
"quizapp/internal/middleware"
"quizapp/internal/observability"
"quizapp/internal/services"
"quizapp/internal/version"
"quizapp/internal/worker"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
)
// fatalIfErr logs err at error level with msg and fields, then panics with the
// message "<msg>: <err>". A nil err is a no-op, as the name implies. Without this
// guard, err.Error() would itself trigger a nil-pointer panic and hide msg.
//
// Because it panics instead of calling os.Exit, the calling goroutine's deferred
// functions still run before the process terminates, unless the panic is
// recovered.
func fatalIfErr(ctx context.Context, logger *observability.Logger, msg string, err error, fields map[string]interface{}) {
	if err == nil {
		return
	}
	logger.Error(ctx, msg, err, fields)
	panic(msg + ": " + err.Error())
}
// main starts the quiz worker service. It loads configuration, sets up
// observability and the database (without migrations), starts the background
// worker, and serves the worker admin/health HTTP API. It runs until SIGINT or
// SIGTERM is received, or until the HTTP server fails to start.
//
// Start-up failures are logged and then panic (via fatalIfErr) from the main
// goroutine, so the deferred telemetry and database cleanup still runs. On a
// shutdown signal the worker is stopped first and then the HTTP server, both
// bounded by config.WorkerShutdownTimeout.
func main() {
	ctx := context.Background()

	// Load configuration
	cfg, err := config.NewConfig()
	if err != nil {
		panic("Failed to load configuration: " + err.Error())
	}

	// Setup observability (tracing/metrics/logging)
	tp, mp, logger, err := observability.SetupObservability(&cfg.OpenTelemetry, "quiz-worker")
	if err != nil {
		panic("Failed to initialize observability: " + err.Error())
	}
	// Flush telemetry providers on exit. Deferred calls only run when main itself
	// returns or panics, which is why server failures are reported from main below.
	defer func() {
		if tp != nil {
			if err := tp.Shutdown(context.TODO()); err != nil {
				logger.Warn(ctx, "Error shutting down tracer provider", map[string]interface{}{"error": err.Error(), "provider": "tracer"})
			}
		}
		if mp != nil {
			if err := mp.Shutdown(context.TODO()); err != nil {
				logger.Warn(ctx, "Error shutting down meter provider", map[string]interface{}{"error": err.Error(), "provider": "meter"})
			}
		}
	}()

	logger.Info(ctx, "Starting quiz worker service", map[string]interface{}{
		"port":     cfg.Server.WorkerPort,
		"logLevel": cfg.Server.LogLevel,
		"debug":    cfg.Server.Debug,
	})

	// Initialize database manager with logger
	dbManager := database.NewManager(logger)

	// Initialize database connection with configuration (no migrations for worker).
	// NOTE(review): cfg.Database.URL is logged below; if the URL embeds
	// credentials they end up in the logs. Consider redacting it.
	db, err := dbManager.InitDBWithoutMigrations(cfg.Database)
	if err != nil {
		fatalIfErr(ctx, logger, "Failed to initialize database", err, map[string]interface{}{"db_url": cfg.Database.URL})
	}
	defer func() {
		if err := db.Close(); err != nil {
			logger.Warn(ctx, "Warning: failed to close database", map[string]interface{}{"error": err.Error(), "db_url": cfg.Database.URL})
		}
	}()

	// Initialize services
	userService := services.NewUserServiceWithLogger(db, cfg, logger)
	learningService := services.NewLearningServiceWithLogger(db, cfg, logger)
	// Create question service
	questionService := services.NewQuestionServiceWithLogger(db, learningService, cfg, logger)
	aiService := services.NewAIService(cfg, logger)
	workerService := services.NewWorkerServiceWithLogger(db, logger)
	generationHintService := services.NewGenerationHintService(db, logger)
	emailService := services.CreateEmailServiceWithDB(cfg, logger, db)
	// Create daily question service
	dailyQuestionService := services.NewDailyQuestionService(db, logger, questionService, learningService)

	// Initialize worker with the observability logger. It runs in the background
	// until workerInstance.Shutdown is called during graceful shutdown below.
	workerInstance := worker.NewWorker(userService, questionService, aiService, learningService, workerService, dailyQuestionService, emailService, generationHintService, "default", cfg, logger)
	go workerInstance.Start(ctx)

	// Initialize admin handler for worker UI
	adminHandler := handlers.NewWorkerAdminHandlerWithLogger(userService, questionService, aiService, cfg, workerInstance, workerService, learningService, dailyQuestionService, logger)

	// Setup Gin router
	gin.SetMode(gin.ReleaseMode)
	if cfg.Server.Debug {
		gin.SetMode(gin.DebugMode)
	}
	router := gin.New()
	router.Use(gin.Recovery())

	// Add HTTP request logging middleware using our observability logger
	router.Use(func(c *gin.Context) {
		start := time.Now()

		// Process request
		c.Next()

		// Log request details using our observability logger
		latency := time.Since(start)
		statusCode := c.Writer.Status()
		clientIP := c.ClientIP()
		method := c.Request.Method
		path := c.Request.URL.Path

		// Create structured log entry
		fields := map[string]interface{}{
			"http.method":      method,
			"http.path":        path,
			"http.status_code": statusCode,
			"http.latency_ms":  latency.Milliseconds(),
			"http.client_ip":   clientIP,
			"http.user_agent":  c.Request.UserAgent(),
		}

		// Add error message if present
		if len(c.Errors) > 0 {
			fields["http.error"] = c.Errors.String()
		}

		// Log using our observability logger (goes to both stdout and OTLP),
		// choosing the level from the status code.
		switch {
		case statusCode >= 500:
			logger.Error(c.Request.Context(), "HTTP request failed", nil, fields)
		case statusCode >= 400:
			logger.Warn(c.Request.Context(), "HTTP request warning", fields)
		default:
			logger.Info(c.Request.Context(), "HTTP request", fields)
		}
	})

	// Add OpenTelemetry middleware for HTTP tracing with automatic error attributes
	router.Use(observability.GinMiddlewareWithErrorHandling("quiz-worker"))

	// Add CORS middleware; preflight OPTIONS requests are answered immediately.
	router.Use(func(c *gin.Context) {
		c.Header("Access-Control-Allow-Origin", "*")
		c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		c.Header("Access-Control-Allow-Headers", "Origin, Content-Type, Content-Length, Accept-Encoding, X-CSRF-Token, Authorization")

		if c.Request.Method == "OPTIONS" {
			c.AbortWithStatus(http.StatusNoContent)
			return
		}

		c.Next()
	})

	// Setup session middleware
	store := cookie.NewStore([]byte(cfg.Server.SessionSecret))
	router.Use(sessions.Sessions(config.SessionName, store))

	// Setup routes
	v1 := router.Group("/v1")
	{
		// Health check route
		v1.GET("/health", func(c *gin.Context) {
			c.JSON(http.StatusOK, gin.H{"status": "ok"})
		})

		// Version route
		v1.GET("/version", func(c *gin.Context) {
			c.JSON(http.StatusOK, gin.H{
				"service":   "worker",
				"version":   version.Version,
				"commit":    version.Commit,
				"buildTime": version.BuildTime,
			})
		})
	}

	// Serve static assets (CSS/JS) for worker admin dashboard. A failed fs.Sub
	// would leave staticFS nil and every /worker request would fail, so treat it
	// as a start-up error instead of discarding it.
	staticFS, err := fs.Sub(handlers.AssetsFS, "templates/assets")
	if err != nil {
		fatalIfErr(ctx, logger, "Failed to load worker static assets", err, map[string]interface{}{"dir": "templates/assets"})
	}
	router.StaticFS("/worker", http.FS(staticFS))

	// Config dump endpoint
	router.GET("/configz", adminHandler.GetConfigz)

	// API routes for worker management
	api := router.Group("/v1")
	{
		// Admin worker endpoints (for frontend)
		adminWorker := api.Group("/admin/worker")
		adminWorker.Use(middleware.RequireAuth())
		{
			adminWorker.GET("/details", adminHandler.GetWorkerDetails)
			adminWorker.GET("/status", adminHandler.GetWorkerStatus)
			adminWorker.GET("/logs", adminHandler.GetActivityLogs)
			adminWorker.POST("/pause", adminHandler.PauseWorker)
			adminWorker.POST("/resume", adminHandler.ResumeWorker)
			adminWorker.POST("/trigger", adminHandler.TriggerWorkerRun)
			adminWorker.GET("/ai-concurrency", adminHandler.GetAIConcurrencyStats)
		}

		// Worker user control endpoints (for pausing/resuming user question generation)
		workerUsers := api.Group("/admin/worker/users")
		workerUsers.Use(middleware.RequireAuth())
		{
			workerUsers.GET("/", adminHandler.GetWorkerUsers)
			workerUsers.POST("/pause", adminHandler.PauseWorkerUser)
			workerUsers.POST("/resume", adminHandler.ResumeWorkerUser)
		}

		// System health for worker
		system := api.Group("/system")
		{
			system.GET("/health", adminHandler.GetSystemHealth)
		}

		// Admin analytics endpoints (for frontend)
		adminAnalytics := api.Group("/admin/worker/analytics")
		adminAnalytics.Use(middleware.RequireAuth())
		{
			adminAnalytics.GET("/priority-scores", adminHandler.GetPriorityAnalytics)
			adminAnalytics.GET("/user-performance", adminHandler.GetUserPerformanceAnalytics)
			adminAnalytics.GET("/generation-intelligence", adminHandler.GetGenerationIntelligence)
			adminAnalytics.GET("/system-health", adminHandler.GetSystemHealthAnalytics)
			adminAnalytics.GET("/comparison", adminHandler.GetUserComparisonAnalytics)
			adminAnalytics.GET("/user/:userID", adminHandler.GetUserPriorityAnalytics)
		}

		// Admin daily questions endpoints (for frontend)
		adminDaily := api.Group("/admin/worker/daily")
		adminDaily.Use(middleware.RequireAuth())
		{
			adminDaily.GET("/users/:userId/questions/:date", adminHandler.GetUserDailyQuestions)
			adminDaily.POST("/users/:userId/questions/:date/regenerate", adminHandler.RegenerateUserDailyQuestions)
		}

		// Admin notification endpoints (for frontend)
		adminNotifications := api.Group("/admin/worker/notifications")
		adminNotifications.Use(middleware.RequireAuth())
		{
			adminNotifications.GET("/stats", adminHandler.GetNotificationStats)
			adminNotifications.GET("/errors", adminHandler.GetNotificationErrors)
			adminNotifications.GET("/sent", adminHandler.GetSentNotifications)
			adminNotifications.POST("/test/create-sent", adminHandler.CreateTestSentNotification)
			adminNotifications.POST("/force-send", adminHandler.ForceSendNotification)
		}
	}

	// Automatic route listing at root path
	routeListing := handlers.NewRouteListingHandler("Worker")
	routeListing.CollectRoutes(router)

	// Root path shows all available routes
	router.GET("/", func(c *gin.Context) {
		// Support JSON output via query parameter
		if c.Query("json") == "true" {
			routeListing.GetRouteListingJSON(c)
		} else {
			routeListing.GetRouteListingPage(c)
		}
	})

	// Create HTTP server. ReadHeaderTimeout bounds how long a client may take to
	// send request headers, so stalled connections (Slowloris) cannot pin
	// goroutines indefinitely.
	const readHeaderTimeout = 10 * time.Second
	srv := &http.Server{
		Addr:              ":" + cfg.Server.WorkerPort,
		Handler:           router,
		ReadHeaderTimeout: readHeaderTimeout,
	}

	// Start server in a goroutine. A listen failure is handed back to main rather
	// than panicking here: a panic on this goroutine would terminate the process
	// without running main's deferred telemetry/database cleanup. The channel is
	// buffered so the goroutine never blocks if main is already shutting down.
	serverErr := make(chan error, 1)
	go func() {
		logger.Info(ctx, "Worker server starting", map[string]interface{}{"port": cfg.Server.WorkerPort})
		if err := srv.ListenAndServe(); err != nil && err != http.ErrServerClosed {
			serverErr <- err
		}
	}()

	// Wait for an interrupt signal (or a server failure) before shutting down
	quit := make(chan os.Signal, 1)
	signal.Notify(quit, syscall.SIGINT, syscall.SIGTERM)
	select {
	case <-quit:
	case err := <-serverErr:
		fatalIfErr(ctx, logger, "Failed to start worker server", err, map[string]interface{}{"port": cfg.Server.WorkerPort})
	}

	logger.Info(ctx, "Worker server shutting down", map[string]interface{}{"service": "worker"})

	// Graceful shutdown with timeout
	shutdownCtx, shutdownCancel := context.WithTimeout(ctx, config.WorkerShutdownTimeout)
	defer shutdownCancel()

	// Shutdown the worker first
	if err := workerInstance.Shutdown(shutdownCtx); err != nil {
		logger.Warn(ctx, "Warning: failed to shutdown worker", map[string]interface{}{"error": err.Error(), "service": "worker"})
	}

	// Then shutdown the server
	if err := srv.Shutdown(shutdownCtx); err != nil {
		fatalIfErr(ctx, logger, "Worker server forced to shutdown", err, map[string]interface{}{"service": "worker"})
	}

	logger.Info(ctx, "Worker server exited", map[string]interface{}{"service": "worker"})
}
// Package api provides primitives to interact with the openapi HTTP API.
//
// Code generated by github.com/oapi-codegen/oapi-codegen/v2 version v2.5.0 DO NOT EDIT.
package api
import (
"encoding/json"
"fmt"
"github.com/oapi-codegen/runtime"
openapi_types "github.com/oapi-codegen/runtime/types"
)
// NOTE: this file is generated by oapi-codegen ("DO NOT EDIT"); change the
// OpenAPI spec and regenerate instead of editing declarations here.
//
// CookieAuthScopes and SessionAuthScopes name the scopes of the cookieAuth and
// sessionAuth security schemes. NOTE(review): presumably used as request-context
// keys by the generated server wrappers; confirm against that code.
const (
CookieAuthScopes = "cookieAuth.Scopes"
SessionAuthScopes = "sessionAuth.Scopes"
)
// Defines values for ChatMessageRole.
const (
ChatMessageRoleAssistant ChatMessageRole = "assistant"
ChatMessageRoleUser ChatMessageRole = "user"
)
// Defines values for NotificationErrorErrorType.
const (
NotificationErrorErrorTypeEmailDisabled NotificationErrorErrorType = "email_disabled"
NotificationErrorErrorTypeOther NotificationErrorErrorType = "other"
NotificationErrorErrorTypeSmtpError NotificationErrorErrorType = "smtp_error"
NotificationErrorErrorTypeTemplateError NotificationErrorErrorType = "template_error"
NotificationErrorErrorTypeUserNotFound NotificationErrorErrorType = "user_not_found"
)
// Defines values for NotificationErrorNotificationType.
const (
NotificationErrorNotificationTypeDailyReminder NotificationErrorNotificationType = "daily_reminder"
NotificationErrorNotificationTypeTestEmail NotificationErrorNotificationType = "test_email"
)
// Defines values for QuestionStatus.
// Note: Active and Reported are unprefixed package-level names.
const (
Active QuestionStatus = "active"
Reported QuestionStatus = "reported"
)
// Defines values for QuestionType.
// Note: these constants are also unprefixed package-level names.
const (
FillBlank QuestionType = "fill_blank"
Qa QuestionType = "qa"
ReadingComprehension QuestionType = "reading_comprehension"
Vocabulary QuestionType = "vocabulary"
)
// Defines values for SentNotificationNotificationType.
const (
SentNotificationNotificationTypeDailyReminder SentNotificationNotificationType = "daily_reminder"
SentNotificationNotificationTypeTestEmail SentNotificationNotificationType = "test_email"
)
// Defines values for SentNotificationStatus.
const (
SentNotificationStatusBounced SentNotificationStatus = "bounced"
SentNotificationStatusFailed SentNotificationStatus = "failed"
SentNotificationStatusSent SentNotificationStatus = "sent"
)
// Defines values for TTSRequestStreamFormat.
const (
Mp3 TTSRequestStreamFormat = "mp3"
Sse TTSRequestStreamFormat = "sse"
Wav TTSRequestStreamFormat = "wav"
)
// Defines values for TTSResponseType.
const (
TTSResponseTypeAudio TTSResponseType = "audio"
TTSResponseTypeError TTSResponseType = "error"
TTSResponseTypeUsage TTSResponseType = "usage"
)
// Defines values for WorkerStatusStatus.
const (
WorkerStatusStatusBusy WorkerStatusStatus = "busy"
WorkerStatusStatusError WorkerStatusStatus = "error"
WorkerStatusStatusIdle WorkerStatusStatus = "idle"
)
// Defines values for GetV1AdminBackendUserzPaginatedParamsAiEnabled.
// Note: this filter is a string enum ("true"/"false"), not a Go bool.
const (
GetV1AdminBackendUserzPaginatedParamsAiEnabledFalse GetV1AdminBackendUserzPaginatedParamsAiEnabled = "false"
GetV1AdminBackendUserzPaginatedParamsAiEnabledTrue GetV1AdminBackendUserzPaginatedParamsAiEnabled = "true"
)
// Defines values for GetV1AdminBackendUserzPaginatedParamsActive.
// Note: this filter is a string enum ("true"/"false"), not a Go bool.
const (
GetV1AdminBackendUserzPaginatedParamsActiveFalse GetV1AdminBackendUserzPaginatedParamsActive = "false"
GetV1AdminBackendUserzPaginatedParamsActiveTrue GetV1AdminBackendUserzPaginatedParamsActive = "true"
)
// Defines values for GetV1AdminWorkerNotificationsErrorsParamsErrorType.
// Same string values as NotificationErrorErrorType, declared as a distinct type.
const (
GetV1AdminWorkerNotificationsErrorsParamsErrorTypeEmailDisabled GetV1AdminWorkerNotificationsErrorsParamsErrorType = "email_disabled"
GetV1AdminWorkerNotificationsErrorsParamsErrorTypeOther GetV1AdminWorkerNotificationsErrorsParamsErrorType = "other"
GetV1AdminWorkerNotificationsErrorsParamsErrorTypeSmtpError GetV1AdminWorkerNotificationsErrorsParamsErrorType = "smtp_error"
GetV1AdminWorkerNotificationsErrorsParamsErrorTypeTemplateError GetV1AdminWorkerNotificationsErrorsParamsErrorType = "template_error"
GetV1AdminWorkerNotificationsErrorsParamsErrorTypeUserNotFound GetV1AdminWorkerNotificationsErrorsParamsErrorType = "user_not_found"
)
// Defines values for GetV1AdminWorkerNotificationsErrorsParamsNotificationType.
// Same string values as NotificationErrorNotificationType, declared as a distinct type.
const (
GetV1AdminWorkerNotificationsErrorsParamsNotificationTypeDailyReminder GetV1AdminWorkerNotificationsErrorsParamsNotificationType = "daily_reminder"
GetV1AdminWorkerNotificationsErrorsParamsNotificationTypeTestEmail GetV1AdminWorkerNotificationsErrorsParamsNotificationType = "test_email"
)
// Defines values for GetV1AdminWorkerNotificationsErrorsParamsResolved.
// Note: False and True are unprefixed package-level string constants of this
// type, not Go booleans.
const (
False GetV1AdminWorkerNotificationsErrorsParamsResolved = "false"
True GetV1AdminWorkerNotificationsErrorsParamsResolved = "true"
)
// Defines values for GetV1AdminWorkerNotificationsSentParamsNotificationType.
// Same string values as SentNotificationNotificationType, declared as a distinct type.
const (
GetV1AdminWorkerNotificationsSentParamsNotificationTypeDailyReminder GetV1AdminWorkerNotificationsSentParamsNotificationType = "daily_reminder"
GetV1AdminWorkerNotificationsSentParamsNotificationTypeTestEmail GetV1AdminWorkerNotificationsSentParamsNotificationType = "test_email"
)
// Defines values for GetV1AdminWorkerNotificationsSentParamsStatus.
// Same string values as SentNotificationStatus, declared as a distinct type.
const (
GetV1AdminWorkerNotificationsSentParamsStatusBounced GetV1AdminWorkerNotificationsSentParamsStatus = "bounced"
GetV1AdminWorkerNotificationsSentParamsStatusFailed GetV1AdminWorkerNotificationsSentParamsStatus = "failed"
GetV1AdminWorkerNotificationsSentParamsStatusSent GetV1AdminWorkerNotificationsSentParamsStatus = "sent"
)
// AIConcurrencyStats defines model for AIConcurrencyStats.
// All fields are optional pointers and are omitted from JSON when nil.
type AIConcurrencyStats struct {
ActiveRequests *int `json:"active_requests,omitempty"`
MaxConcurrent *int `json:"max_concurrent,omitempty"`
MaxPerUser *int `json:"max_per_user,omitempty"`
QueuedRequests *int `json:"queued_requests,omitempty"`
TotalRequests *int `json:"total_requests,omitempty"`
// UserActiveCount maps a user key to that user's active request count.
// NOTE(review): key format (user ID vs username) is not visible here; confirm.
UserActiveCount *map[string]int `json:"user_active_count,omitempty"`
}
// AIProviders defines model for AIProviders.
// Providers is a list of inline provider descriptions, each with its own Models list.
type AIProviders struct {
Levels *[]string `json:"levels,omitempty"`
Providers *[]struct {
Code *string `json:"code,omitempty"`
Models *[]struct {
Code *string `json:"code,omitempty"`
Name *string `json:"name,omitempty"`
} `json:"models,omitempty"`
Name *string `json:"name,omitempty"`
Url *string `json:"url,omitempty"`
} `json:"providers,omitempty"`
}
// APIKeyAvailabilityResponse defines model for APIKeyAvailabilityResponse.
type APIKeyAvailabilityResponse struct {
// HasApiKey Whether the user has a saved API key for this provider
HasApiKey bool `json:"has_api_key"`
}
// AggregatedVersion defines model for AggregatedVersion.
// Worker is a union: either version data or an AggregatedVersionWorker1 error.
type AggregatedVersion struct {
Backend ServiceVersion `json:"backend"`
Worker AggregatedVersion_Worker `json:"worker"`
}
// AggregatedVersionWorker1 is the error variant of AggregatedVersion.Worker,
// used when the worker service is unavailable.
type AggregatedVersionWorker1 struct {
// Error Error message when worker is unavailable
Error string `json:"error"`
}
// AggregatedVersion_Worker defines model for AggregatedVersion.Worker.
// It stores the undecoded JSON of the union in an unexported field.
// NOTE(review): decode it with the generated As*/From* helpers (not visible in
// this chunk); the non-error variant is presumably ServiceVersion. Confirm.
type AggregatedVersion_Worker struct {
union json.RawMessage
}
// AnswerRequest defines model for AnswerRequest.
type AnswerRequest struct {
// QuestionId ID of the question being answered
QuestionId int64 `json:"question_id"`
// ResponseTimeMs Response time in milliseconds (0-5 minutes)
ResponseTimeMs *int32 `json:"response_time_ms,omitempty"`
// UserAnswerIndex Index of the user's selected answer in the original options array (0-based)
// Note: a plain int, so a request that omits this field decodes as 0 (the first option).
UserAnswerIndex int `json:"user_answer_index"`
}
// AnswerResponse defines model for AnswerResponse.
// All fields are optional pointers and are omitted from JSON when nil.
type AnswerResponse struct {
// CorrectAnswerIndex Index of the correct answer in the options array (0-based)
CorrectAnswerIndex *int `json:"correct_answer_index,omitempty"`
Explanation *string `json:"explanation,omitempty"`
IsCorrect *bool `json:"is_correct,omitempty"`
NextDifficulty *string `json:"next_difficulty,omitempty"`
// UserAnswer The answer selected by the user
UserAnswer *string `json:"user_answer,omitempty"`
// UserAnswerIndex Index of the user's selected answer in the original options array (0-based)
UserAnswerIndex *int `json:"user_answer_index,omitempty"`
}
// AuthStatusResponse defines model for AuthStatusResponse.
type AuthStatusResponse struct {
// Authenticated Whether the user is currently authenticated
Authenticated bool `json:"authenticated"`
User User `json:"user"`
}
// ChatMessage defines model for ChatMessage.
type ChatMessage struct {
// Content The message content
Content string `json:"content"`
// Role The role of the message sender
Role ChatMessageRole `json:"role"`
}
// ChatMessageRole The role of the message sender
type ChatMessageRole string
// DailyProgress defines model for DailyProgress.
type DailyProgress struct {
// Completed Number of completed questions
Completed int `json:"completed"`
// Date Date for the progress report (YYYY-MM-DD)
Date openapi_types.Date `json:"date"`
// Total Total number of questions assigned for the date
Total int `json:"total"`
}
// DailyQuestionHistory defines model for DailyQuestionHistory.
// IsCorrect and SubmittedAt have no omitempty, so nil is encoded as JSON null.
type DailyQuestionHistory struct {
// AssignmentDate RFC3339 timestamp of when the question was assigned in the user's timezone (includes offset)
AssignmentDate string `json:"assignment_date"`
// IsCompleted Whether the question was completed on this date
IsCompleted bool `json:"is_completed"`
// IsCorrect Whether the user's answer was correct (null if not attempted)
IsCorrect *bool `json:"is_correct"`
// SubmittedAt When the user submitted their answer
SubmittedAt *string `json:"submitted_at"`
}
// DailyQuestionWithDetails defines model for DailyQuestionWithDetails.
// CompletedAt, SubmittedAt and UserAnswerIndex have no omitempty, so nil is
// encoded as JSON null. The User* counters are omitted when nil.
type DailyQuestionWithDetails struct {
// AssignmentDate Date-only assignment (YYYY-MM-DD) representing the logical calendar day the question was assigned (no timezone offset)
AssignmentDate openapi_types.Date `json:"assignment_date"`
// CompletedAt When the question was completed (if completed)
CompletedAt *string `json:"completed_at"`
// CreatedAt When the assignment was created
CreatedAt string `json:"created_at"`
// Id Daily question assignment ID
Id int64 `json:"id"`
// IsCompleted Whether the question has been completed
IsCompleted bool `json:"is_completed"`
Question Question `json:"question"`
// QuestionId Question ID
QuestionId int64 `json:"question_id"`
// SubmittedAt When the user submitted their answer
SubmittedAt *string `json:"submitted_at"`
// UserAnswerIndex The index of the answer option the user selected (0-based)
UserAnswerIndex *int `json:"user_answer_index"`
// UserCorrectCount Number of times this user answered this question correctly
UserCorrectCount *int64 `json:"user_correct_count,omitempty"`
// UserId User ID
UserId int64 `json:"user_id"`
// UserIncorrectCount Number of times this user answered this question incorrectly
UserIncorrectCount *int64 `json:"user_incorrect_count,omitempty"`
// UserShownCount Number of times this question was shown to this user in Daily view
UserShownCount *int64 `json:"user_shown_count,omitempty"`
// UserTotalResponses Number of times this user answered this question
UserTotalResponses *int64 `json:"user_total_responses,omitempty"`
}
// DashboardResponse defines model for DashboardResponse.
// All fields are optional pointers and are omitted from JSON when nil.
type DashboardResponse struct {
AiConcurrencyStats *AIConcurrencyStats `json:"ai_concurrency_stats,omitempty"`
QuestionStats *QuestionStats `json:"question_stats,omitempty"`
Users *[]DashboardUser `json:"users,omitempty"`
WorkerBaseUrl *string `json:"worker_base_url,omitempty"`
WorkerHealth *WorkerHealth `json:"worker_health,omitempty"`
WorkerPort *string `json:"worker_port,omitempty"`
}
// DashboardUser defines model for DashboardUser.
type DashboardUser struct {
Progress *UserProgress `json:"progress,omitempty"`
QuestionStats *UserQuestionStats `json:"question_stats,omitempty"`
User *UserProfile `json:"user,omitempty"`
}
// ErrorResponse defines model for ErrorResponse.
// Both fields are optional and omitted from JSON when nil.
type ErrorResponse struct {
Details *string `json:"details,omitempty"`
Error *string `json:"error,omitempty"`
}
// ForceSendNotificationResponse defines model for ForceSendNotificationResponse.
// Notification and User are inline structs; all fields are optional pointers.
type ForceSendNotificationResponse struct {
Message *string `json:"message,omitempty"`
Notification *struct {
Status *string `json:"status,omitempty"`
Subject *string `json:"subject,omitempty"`
Type *string `json:"type,omitempty"`
} `json:"notification,omitempty"`
User *struct {
Email *string `json:"email,omitempty"`
Id *int64 `json:"id,omitempty"`
Username *string `json:"username,omitempty"`
} `json:"user,omitempty"`
}
// GeneratingResponse defines model for GeneratingResponse.
type GeneratingResponse struct {
// AiModel User's preferred AI model
AiModel *string `json:"ai_model,omitempty"`
// ApiKey User's API key for the selected provider (write-only)
// NOTE(review): "write-only" is not enforced by this struct; a non-nil value is
// serialized in responses, so handlers must leave it nil when replying.
ApiKey *string `json:"api_key,omitempty"`
Message *string `json:"message,omitempty"`
Status *string `json:"status,omitempty"`
}
// GenerationFocus defines model for GenerationFocus.
type GenerationFocus struct {
// CurrentGenerationModel The AI model currently being used for generation
CurrentGenerationModel *string `json:"current_generation_model,omitempty"`
// GenerationRate Average number of questions generated per minute
GenerationRate *float32 `json:"generation_rate,omitempty"`
// LastGenerationTime Timestamp of the last time a question was generated
LastGenerationTime *string `json:"last_generation_time,omitempty"`
}
// GenerationIntelligence defines model for GenerationIntelligence.
// Note: unlike most models here, its JSON keys are camelCase.
type GenerationIntelligence struct {
GapAnalysis *[]map[string]interface{} `json:"gapAnalysis,omitempty"`
GenerationSuggestions *[]map[string]interface{} `json:"generationSuggestions,omitempty"`
}
// GoogleOAuthLoginResponse defines model for GoogleOAuthLoginResponse.
type GoogleOAuthLoginResponse struct {
// AuthUrl The Google OAuth authorization URL to redirect the user to
AuthUrl string `json:"auth_url"`
}
// Language Learning language (dynamic). Allowed values come from config.yaml language_levels keys.
// It is a type alias for string, so the compiler does not restrict its values.
type Language = string
// LanguagesResponse Array of available learning languages
type LanguagesResponse = []string
// Level Proficiency level (dynamic). Allowed values depend on the selected language and are sourced from config.yaml (e.g., CEFR A1-C2, JLPT N5-N1, HSK1-HSK6).
// It is a type alias for string, so the compiler does not restrict its values.
type Level = string
// LevelsResponse defines model for LevelsResponse.
type LevelsResponse struct {
// LevelDescriptions Mapping from level code to short label (e.g. Beginner, Intermediate)
LevelDescriptions map[string]string `json:"level_descriptions"`
// Levels Array of available language proficiency levels
Levels []string `json:"levels"`
}
// LoginRequest defines model for LoginRequest.
type LoginRequest struct {
// Password Password (minimum 8 characters)
Password string `json:"password"`
// Username Username (1-100 characters, alphanumeric + underscore + email characters, cannot be empty or whitespace-only)
Username string `json:"username"`
}
// LoginResponse defines model for LoginResponse.
type LoginResponse struct {
Message *string `json:"message,omitempty"`
// RedirectUri Redirect URI for OAuth flows (optional)
RedirectUri *string `json:"redirect_uri,omitempty"`
Success *bool `json:"success,omitempty"`
User *User `json:"user,omitempty"`
}
// NotificationError defines model for NotificationError.
type NotificationError struct {
// EmailAddress Email address that was being used
EmailAddress *string `json:"email_address"`
// ErrorMessage Detailed error message
ErrorMessage *string `json:"error_message,omitempty"`
// ErrorType Type of error that occurred
ErrorType *NotificationErrorErrorType `json:"error_type,omitempty"`
Id *int64 `json:"id,omitempty"`
// NotificationType Type of notification that failed
NotificationType *NotificationErrorNotificationType `json:"notification_type,omitempty"`
// OccurredAt When the error occurred
OccurredAt *string `json:"occurred_at,omitempty"`
// ResolutionNotes Notes about how the error was resolved
ResolutionNotes *string `json:"resolution_notes"`
// ResolvedAt When the error was resolved
ResolvedAt *string `json:"resolved_at"`
UserId *int64 `json:"user_id"`
// Username Username of the user (if available)
Username *string `json:"username,omitempty"`
}
// NotificationErrorErrorType Type of error that occurred
type NotificationErrorErrorType string
// NotificationErrorNotificationType Type of notification that failed
type NotificationErrorNotificationType string
// NotificationErrorStats defines model for NotificationErrorStats.
type NotificationErrorStats struct {
// ErrorsByNotificationType Breakdown of errors by notification type
ErrorsByNotificationType *map[string]int `json:"errors_by_notification_type,omitempty"`
// ErrorsByType Breakdown of errors by type
ErrorsByType *map[string]int `json:"errors_by_type,omitempty"`
// TotalErrors Total number of errors
TotalErrors *int `json:"total_errors,omitempty"`
// UnresolvedErrors Number of unresolved errors
UnresolvedErrors *int `json:"unresolved_errors,omitempty"`
}
// NotificationStats defines model for NotificationStats.
//
// This is aggregate delivery counters for notifications. All fields are
// optional and are omitted from JSON when nil.
type NotificationStats struct {
// NotificationsByType Breakdown of notifications by type
NotificationsByType *map[string]int `json:"notifications_by_type,omitempty"`
// SentThisWeek Number of notifications sent this week
SentThisWeek *int `json:"sent_this_week,omitempty"`
// SentToday Number of notifications sent today
SentToday *int `json:"sent_today,omitempty"`
// SuccessRate Success rate as a percentage (0-1)
// NOTE(review): "percentage" contradicts the stated 0-1 range. The value is
// presumably a fraction (e.g. 0.95, not 95). Confirm and fix the spec wording.
SuccessRate *float32 `json:"success_rate,omitempty"`
// TotalFailed Total number of notifications that failed
TotalFailed *int `json:"total_failed,omitempty"`
// TotalSent Total number of notifications sent
TotalSent *int `json:"total_sent,omitempty"`
}
// PaginationInfo defines model for PaginationInfo.
//
// All fields are non-pointer and have no ",omitempty", so they are always
// present in JSON output (zero values encode as 0).
//
// NOTE(review): the request params in this file document pages as 1-based;
// presumably Page here follows the same convention. Confirm.
type PaginationInfo struct {
// Page Current page number
Page int `json:"page"`
// PageSize Number of items per page
PageSize int `json:"page_size"`
// Total Total number of items
Total int `json:"total"`
// TotalPages Total number of pages
TotalPages int `json:"total_pages"`
}
// PasswordResetRequest defines model for PasswordResetRequest.
type PasswordResetRequest struct {
// NewPassword New password (minimum 8 characters)
// The minimum length is a documented contract only; this plain string field
// performs no validation, so handlers must check it.
NewPassword string `json:"new_password"`
}
// PerformanceMetrics defines model for PerformanceMetrics.
//
// This is per-topic answer statistics; it is used as the map value type in
// UserProgress.PerformanceByTopic. All fields are optional and omitted from
// JSON when nil.
type PerformanceMetrics struct {
// AverageResponseTimeMs Average response time; the "_ms" suffix in the tag
// indicates milliseconds.
AverageResponseTimeMs *float32 `json:"average_response_time_ms,omitempty"`
// CorrectAttempts Number of correct attempts.
CorrectAttempts *int `json:"correct_attempts,omitempty"`
// LastUpdated Timestamp carried as a string; the format is not fixed by this type.
LastUpdated *string `json:"last_updated,omitempty"`
// TotalAttempts Number of attempts overall.
TotalAttempts *int `json:"total_attempts,omitempty"`
}
// PriorityInsights defines model for PriorityInsights.
//
// This gives counts of queued questions bucketed by priority. All fields are
// optional and omitted from JSON when nil.
type PriorityInsights struct {
// HighPriorityQuestions Number of high-priority questions
HighPriorityQuestions *int `json:"high_priority_questions,omitempty"`
// LowPriorityQuestions Number of low-priority questions
LowPriorityQuestions *int `json:"low_priority_questions,omitempty"`
// MediumPriorityQuestions Number of medium-priority questions
MediumPriorityQuestions *int `json:"medium_priority_questions,omitempty"`
// TotalQuestionsInQueue Total number of questions waiting to be processed
TotalQuestionsInQueue *int `json:"total_questions_in_queue,omitempty"`
}
// Question defines model for Question.
//
// This is a quiz question plus its metadata and usage counters. Every field is
// an optional pointer and is omitted from JSON when nil.
type Question struct {
// ConfidenceLevel Confidence level when question was marked as known (1-5)
ConfidenceLevel *int `json:"confidence_level,omitempty"`
// Content All question types now use multiple choice format with 4 options
Content *QuestionContent `json:"content,omitempty"`
// CorrectAnswer Index of the correct answer in the options array (0-based),
// i.e. an index into Content.Options.
CorrectAnswer *int `json:"correct_answer,omitempty"`
// CorrectCount Number of times this question was answered correctly
CorrectCount *int `json:"correct_count,omitempty"`
// CreatedAt Creation timestamp carried as a string.
CreatedAt *string `json:"created_at,omitempty"`
// DifficultyModifier Difficulty modifier for the question (e.g., basic, intermediate)
DifficultyModifier *string `json:"difficulty_modifier,omitempty"`
DifficultyScore *float32 `json:"difficulty_score,omitempty"`
Explanation *string `json:"explanation,omitempty"`
// GrammarFocus Grammar focus area for the question (e.g., present_perfect, conditionals)
GrammarFocus *string `json:"grammar_focus,omitempty"`
Id *int64 `json:"id,omitempty"`
// IncorrectCount Number of times this question was answered incorrectly
IncorrectCount *int `json:"incorrect_count,omitempty"`
// Language Learning language (dynamic). Allowed values come from config.yaml language_levels keys.
Language *Language `json:"language,omitempty"`
// Level Proficiency level (dynamic). Allowed values depend on the selected language and are sourced from config.yaml (e.g., CEFR A1-C2, JLPT N5-N1, HSK1-HSK6).
Level *Level `json:"level,omitempty"`
// Reporters Comma-separated list of usernames who reported this question.
// This is a single string, not a slice; split on "," to get individual names.
Reporters *string `json:"reporters,omitempty"`
// Scenario Scenario context for the question (e.g., at_the_airport, in_a_restaurant)
Scenario *string `json:"scenario,omitempty"`
Status *QuestionStatus `json:"status,omitempty"`
// StyleModifier Style modifier for the question (e.g., conversational, formal)
StyleModifier *string `json:"style_modifier,omitempty"`
// TimeContext Time context for the question (e.g., morning_routine, workday)
TimeContext *string `json:"time_context,omitempty"`
// TopicCategory General topic category for question context (e.g., daily_life, travel, work)
TopicCategory *string `json:"topic_category,omitempty"`
// TotalResponses Total number of responses to this question (used for 'Shown' in the UI)
TotalResponses *int `json:"total_responses,omitempty"`
Type *QuestionType `json:"type,omitempty"`
// UserCount Number of users assigned to this question
UserCount *int `json:"user_count,omitempty"`
// VocabularyDomain Vocabulary domain for the question (e.g., food_and_dining, transportation)
VocabularyDomain *string `json:"vocabulary_domain,omitempty"`
}
// QuestionContent All question types now use multiple choice format with 4 options.
//
// The "4 options" rule is not enforced by this type; Options is a plain slice.
// Options and Question have no ",omitempty" and are always emitted. A nil
// Options slice encodes as JSON null, not [].
type QuestionContent struct {
// Hint Optional hint for fill-in-blank questions
Hint *string `json:"hint,omitempty"`
// Options Answer choices; Question.CorrectAnswer indexes into this slice.
Options []string `json:"options"`
// Passage Only present for reading comprehension questions
Passage *string `json:"passage,omitempty"`
// Question The question text shown to the user.
Question string `json:"question"`
// Sentence Only present for vocabulary questions (context sentence)
Sentence *string `json:"sentence,omitempty"`
}
// QuestionStats defines model for QuestionStats.
//
// This is aggregate question and response counts. All fields are optional and
// omitted from JSON when nil.
type QuestionStats struct {
// QuestionsByLanguage Breakdown of questions by language
QuestionsByLanguage *map[string]int `json:"questions_by_language,omitempty"`
// QuestionsByLevel Breakdown of questions by level
QuestionsByLevel *map[string]int `json:"questions_by_level,omitempty"`
// QuestionsByType Breakdown of questions by type
QuestionsByType *map[string]int `json:"questions_by_type,omitempty"`
// TotalQuestions Total number of questions
TotalQuestions *int `json:"total_questions,omitempty"`
// TotalResponses Total number of responses
TotalResponses *int `json:"total_responses,omitempty"`
}
// QuestionStatus defines model for QuestionStatus.
// NOTE(review): the allowed values are presumably declared as constants elsewhere in this file. Confirm.
type QuestionStatus string
// QuestionType defines model for QuestionType.
// NOTE(review): the allowed values are presumably declared as constants elsewhere in this file. Confirm.
type QuestionType string
// QuizChatRequest defines model for QuizChatRequest.
//
// This is the request body for chatting with the AI tutor about a question.
// Question and UserMessage are required (non-pointer, no ",omitempty"). The
// answer context and history are optional.
type QuizChatRequest struct {
// AnswerContext The user's answer to the question, if one was submitted.
AnswerContext *AnswerResponse `json:"answer_context,omitempty"`
// ConversationHistory Previous messages in the conversation
ConversationHistory *[]ChatMessage `json:"conversation_history,omitempty"`
// Question The question under discussion (embedded by value).
Question Question `json:"question"`
// UserMessage The user's message to the AI tutor.
UserMessage string `json:"user_message"`
}
// Role defines model for Role.
//
// All fields are required and always emitted. Timestamps are plain strings;
// their format is not fixed by this type.
type Role struct {
// CreatedAt When the role was created
CreatedAt string `json:"created_at"`
// Description Role description
Description string `json:"description"`
// Id Role ID
Id int64 `json:"id"`
// Name Role name (e.g., "user", "admin")
Name string `json:"name"`
// UpdatedAt When the role was last updated
UpdatedAt string `json:"updated_at"`
}
// SentNotification defines model for SentNotification.
//
// This is one record of a notification send attempt. ErrorMessage has no
// ",omitempty", so it is always emitted (JSON null when nil). All other
// pointer fields are omitted when nil.
type SentNotification struct {
// EmailAddress Email address the notification was sent to
EmailAddress *string `json:"email_address,omitempty"`
// ErrorMessage Error message if the notification failed
ErrorMessage *string `json:"error_message"`
Id *int64 `json:"id,omitempty"`
// NotificationType Type of notification
NotificationType *SentNotificationNotificationType `json:"notification_type,omitempty"`
// RetryCount Number of times the notification was retried
RetryCount *int `json:"retry_count,omitempty"`
// SentAt When the notification was sent
SentAt *string `json:"sent_at,omitempty"`
// Status Status of the notification
Status *SentNotificationStatus `json:"status,omitempty"`
// Subject Subject line of the email
Subject *string `json:"subject,omitempty"`
// TemplateName Template used for the notification
TemplateName *string `json:"template_name,omitempty"`
UserId *int64 `json:"user_id,omitempty"`
// Username Username of the user
Username *string `json:"username,omitempty"`
}
// SentNotificationNotificationType Type of notification.
// NOTE(review): the allowed values are presumably declared as constants elsewhere. Confirm.
type SentNotificationNotificationType string
// SentNotificationStatus Status of the notification.
// NOTE(review): the allowed values are presumably declared as constants elsewhere. Confirm.
type SentNotificationStatus string
// ServiceVersion defines model for ServiceVersion.
//
// Note that the "buildTime" key is camelCase, unlike the snake_case keys used by
// most models in this file. It is part of the wire format, so do not "fix" the
// tag without a coordinated client change.
type ServiceVersion struct {
// BuildTime Build timestamp (ISO8601)
BuildTime string `json:"buildTime"`
// Commit Git commit hash
Commit string `json:"commit"`
// Service Service name (e.g., 'backend', 'worker')
Service string `json:"service"`
// Version Version string (e.g., git tag or 'dev')
Version string `json:"version"`
}
// SuccessResponse defines model for SuccessResponse.
//
// Success is always emitted. Message is omitted from JSON when nil.
type SuccessResponse struct {
Message *string `json:"message,omitempty"`
Success bool `json:"success"`
}
// SystemHealthAnalytics defines model for SystemHealthAnalytics.
//
// Both sections are untyped JSON objects (map[string]interface{}), so their
// schema is not constrained here. The keys are camelCase ("backgroundJobs",
// "performance").
type SystemHealthAnalytics struct {
BackgroundJobs *map[string]interface{} `json:"backgroundJobs,omitempty"`
Performance *map[string]interface{} `json:"performance,omitempty"`
}
// TTSRequest defines model for TTSRequest.
//
// Input is required. Model, StreamFormat and Voice are optional and omitted from
// JSON when nil.
// NOTE(review): the server presumably substitutes defaults when they are nil. Confirm.
type TTSRequest struct {
// Input The text to convert to speech
Input string `json:"input"`
// Model The TTS model to use
Model *string `json:"model,omitempty"`
// StreamFormat The format for streaming audio data
StreamFormat *TTSRequestStreamFormat `json:"stream_format,omitempty"`
// Voice The voice to use for speech generation
Voice *string `json:"voice,omitempty"`
}
// TTSRequestStreamFormat The format for streaming audio data.
// NOTE(review): the allowed values are presumably declared as constants elsewhere. Confirm.
type TTSRequestStreamFormat string
// TTSResponse defines model for TTSResponse.
//
// This is one server-sent event in a TTS stream. Type selects which payload
// field is meaningful: Audio for type=audio, Error for type=error, Usage for
// type=usage. The others are nil and omitted from JSON.
type TTSResponse struct {
// Audio Base64 encoded audio chunk (for type=audio)
Audio *string `json:"audio,omitempty"`
// Error Error message (for type=error)
Error *string `json:"error,omitempty"`
// Type The type of SSE event
Type *TTSResponseType `json:"type,omitempty"`
// Usage Usage statistics (for type=usage)
Usage *struct {
// InputTokens Number of input tokens processed
InputTokens *int `json:"input_tokens,omitempty"`
// OutputTokens Number of output tokens generated
OutputTokens *int `json:"output_tokens,omitempty"`
// TotalTokens Total tokens used
TotalTokens *int `json:"total_tokens,omitempty"`
} `json:"usage,omitempty"`
}
// TTSResponseType The type of SSE event.
// NOTE(review): the allowed values are presumably declared as constants elsewhere. Confirm.
type TTSResponseType string
// TestAIRequest defines model for TestAIRequest.
//
// ApiKey has no ",omitempty": when nil it still encodes as "api_key": null.
// When set, it carries a secret, so avoid logging this struct verbatim.
type TestAIRequest struct {
// ApiKey API key for the provider. If not provided, the server will try to use a saved key.
ApiKey *string `json:"api_key"`
// Model AI model code (e.g., "llama3", "gpt-4")
Model string `json:"model"`
// Provider AI provider code (e.g., "ollama", "openai")
Provider string `json:"provider"`
}
// User defines model for User.
//
// JSON encoding:
//   - Fields without ",omitempty" (ai_enabled, ai_model, ai_provider,
//     current_level, email, last_active, preferred_language, timezone) are
//     always emitted and encode as JSON null when nil.
//   - The remaining pointer fields are omitted when nil.
//
// Timestamps (CreatedAt, LastActive) are plain strings.
type User struct {
// AiEnabled Whether AI features are enabled for this user
AiEnabled *bool `json:"ai_enabled"`
AiModel *string `json:"ai_model"`
AiProvider *string `json:"ai_provider"`
CreatedAt *string `json:"created_at,omitempty"`
CurrentLevel *string `json:"current_level"`
Email *string `json:"email"`
// HasApiKey Whether the user has a valid API key saved for their current AI provider
HasApiKey *bool `json:"has_api_key,omitempty"`
Id *int64 `json:"id,omitempty"`
// IsPaused Whether the user is paused (question generation disabled)
IsPaused *bool `json:"is_paused,omitempty"`
LastActive *string `json:"last_active"`
PreferredLanguage *string `json:"preferred_language"`
// Roles List of roles assigned to the user
Roles *[]Role `json:"roles,omitempty"`
Timezone *string `json:"timezone"`
// Username Username (1-100 characters, alphanumeric + underscore + email characters, cannot be empty or whitespace-only)
Username *string `json:"username,omitempty"`
}
// UserCreateRequest defines model for UserCreateRequest.
//
// Username and Password are required (non-pointer, always emitted). The other
// fields are optional and omitted from JSON when nil. The documented length and
// character rules are not enforced by this type; handlers must validate them.
type UserCreateRequest struct {
// AiEnabled Whether AI features are enabled for this user
AiEnabled *bool `json:"ai_enabled,omitempty"`
// CurrentLevel Current proficiency level
CurrentLevel *string `json:"current_level,omitempty"`
// Email Email address
Email *openapi_types.Email `json:"email,omitempty"`
// Password Password (minimum 8 characters)
Password string `json:"password"`
// PreferredLanguage Preferred learning language
PreferredLanguage *string `json:"preferred_language,omitempty"`
// Timezone Timezone (e.g., "UTC", "America/New_York")
Timezone *string `json:"timezone,omitempty"`
// Username Username (1-100 characters, alphanumeric + underscore + email characters, cannot be empty or whitespace-only)
Username string `json:"username"`
}
// UserLearningPreferences defines model for UserLearningPreferences.
//
// Most fields are non-pointer without ",omitempty". They are always emitted,
// and on decode a missing key leaves the Go zero value (false / 0). That cannot
// be distinguished from an explicit false / 0 sent by the client. Only
// DailyGoal and TtsVoice are optional pointers.
type UserLearningPreferences struct {
// DailyGoal User-configurable number of daily questions
DailyGoal *int `json:"daily_goal,omitempty"`
// DailyReminderEnabled Whether to receive daily reminder emails
DailyReminderEnabled bool `json:"daily_reminder_enabled"`
// FocusOnWeakAreas Whether to focus on weak areas
FocusOnWeakAreas bool `json:"focus_on_weak_areas"`
// FreshQuestionRatio Ratio of fresh (never seen) questions to show (0-1)
FreshQuestionRatio float32 `json:"fresh_question_ratio"`
// KnownQuestionPenalty Penalty multiplier for questions marked as known (0-1)
KnownQuestionPenalty float32 `json:"known_question_penalty"`
// ReviewIntervalDays Days between reviews of known questions
ReviewIntervalDays int `json:"review_interval_days"`
// TtsVoice Preferred TTS voice (e.g., it-IT-IsabellaNeural)
TtsVoice *string `json:"tts_voice,omitempty"`
// WeakAreaBoost Multiplier for weak area questions
WeakAreaBoost float32 `json:"weak_area_boost"`
}
// UserPerformanceAnalytics defines model for UserPerformanceAnalytics.
//
// Both fields are untyped JSON (map[string]interface{}), so their schema is not
// constrained here. The keys are camelCase.
type UserPerformanceAnalytics struct {
LearningPreferences *map[string]interface{} `json:"learningPreferences,omitempty"`
WeakAreas *[]map[string]interface{} `json:"weakAreas,omitempty"`
}
// UserProfile defines model for UserProfile.
//
// Fields without ",omitempty" (ai_enabled, email, last_active,
// preferred_language, timezone) are always emitted and encode as JSON null
// when nil. The other pointer fields are omitted when nil.
type UserProfile struct {
// AiEnabled Whether AI features are enabled for this user
AiEnabled *bool `json:"ai_enabled"`
CreatedAt *string `json:"created_at,omitempty"`
CurrentLevel *string `json:"current_level,omitempty"`
Email *string `json:"email"`
Id *int64 `json:"id,omitempty"`
// IsPaused Whether the user is paused (question generation disabled)
IsPaused *bool `json:"is_paused,omitempty"`
LastActive *string `json:"last_active"`
PreferredLanguage *string `json:"preferred_language"`
Timezone *string `json:"timezone"`
UpdatedAt *string `json:"updated_at,omitempty"`
// Username Username (1-100 characters, alphanumeric + underscore + email characters, cannot be empty or whitespace-only)
Username *string `json:"username,omitempty"`
}
// UserProgress defines model for UserProgress.
//
// This is a user's progress snapshot: accuracy, level, per-topic metrics,
// priorities, recent activity and worker state. All fields are optional and
// omitted from JSON when nil.
type UserProgress struct {
AccuracyRate *float32 `json:"accuracy_rate,omitempty"`
CorrectAnswers *int `json:"correct_answers,omitempty"`
// CurrentLevel Proficiency level (dynamic). Allowed values depend on the selected language and are sourced from config.yaml (e.g., CEFR A1-C2, JLPT N5-N1, HSK1-HSK6).
CurrentLevel *Level `json:"current_level,omitempty"`
// GapAnalysis Analysis of learning gaps and areas needing attention
GapAnalysis *map[string]interface{} `json:"gap_analysis,omitempty"`
GenerationFocus *GenerationFocus `json:"generation_focus,omitempty"`
// HighPriorityTopics Topics that have high priority scores for the user
HighPriorityTopics *[]string `json:"high_priority_topics,omitempty"`
LearningPreferences *UserLearningPreferences `json:"learning_preferences,omitempty"`
// PerformanceByTopic Per-topic metrics keyed by topic name.
PerformanceByTopic *map[string]PerformanceMetrics `json:"performance_by_topic,omitempty"`
// PriorityDistribution Distribution of question priorities (high, medium, low counts)
PriorityDistribution *map[string]int `json:"priority_distribution,omitempty"`
PriorityInsights *PriorityInsights `json:"priority_insights,omitempty"`
RecentActivity *[]UserResponse `json:"recent_activity,omitempty"`
// SuggestedLevel Proficiency level (dynamic). Allowed values depend on the selected language and are sourced from config.yaml (e.g., CEFR A1-C2, JLPT N5-N1, HSK1-HSK6).
SuggestedLevel *Level `json:"suggested_level,omitempty"`
TotalQuestions *int `json:"total_questions,omitempty"`
WeakAreas *[]string `json:"weak_areas,omitempty"`
WorkerStatus *WorkerStatus `json:"worker_status,omitempty"`
}
// UserQuestionStats defines model for UserQuestionStats.
//
// This gives per-user counts and accuracy broken down by level and by question
// type. All fields are optional and omitted from JSON when nil.
type UserQuestionStats struct {
AccuracyByLevel *map[string]float32 `json:"accuracy_by_level,omitempty"`
AccuracyByType *map[string]float32 `json:"accuracy_by_type,omitempty"`
AnsweredByLevel *map[string]int `json:"answered_by_level,omitempty"`
AnsweredByType *map[string]int `json:"answered_by_type,omitempty"`
AvailableByLevel *map[string]int `json:"available_by_level,omitempty"`
AvailableByType *map[string]int `json:"available_by_type,omitempty"`
TotalAnswered *int `json:"total_answered,omitempty"`
UserId *int64 `json:"user_id,omitempty"`
}
// UserResponse defines model for UserResponse.
//
// This is one recorded answer by a user (used in UserProgress.RecentActivity).
// All fields are optional and omitted from JSON when nil.
type UserResponse struct {
CreatedAt *string `json:"created_at,omitempty"`
IsCorrect *bool `json:"is_correct,omitempty"`
QuestionId *int64 `json:"question_id,omitempty"`
}
// UserSettings defines model for UserSettings.
//
// This struct has no security-related guarantees of its own. ApiKey is
// documented as write-only, but nothing here prevents it from being marshaled:
// with ",omitempty" it is only dropped when nil. Clear it before returning this
// type in a response.
type UserSettings struct {
// AiEnabled Whether AI features are enabled for this user
AiEnabled *bool `json:"ai_enabled,omitempty"`
AiModel *string `json:"ai_model,omitempty"`
AiProvider *string `json:"ai_provider,omitempty"`
// ApiKey API key for AI provider (write-only)
ApiKey *string `json:"api_key,omitempty"`
// Language Learning language (dynamic). Allowed values come from config.yaml language_levels keys.
Language *Language `json:"language,omitempty"`
// Level Proficiency level (dynamic). Allowed values depend on the selected language and are sourced from config.yaml (e.g., CEFR A1-C2, JLPT N5-N1, HSK1-HSK6).
Level *Level `json:"level,omitempty"`
// union is unexported, so encoding/json ignores it by default.
// NOTE(review): it is presumably populated by custom (Un)MarshalJSON methods
// defined elsewhere for the spec's oneOf/anyOf variants. Confirm.
union json.RawMessage
}
// UserSettings0 is an untyped (interface{}) alias for the first variant of the
// UserSettings union. The spec gives this variant no name, hence the generator's
// empty "defines model for ." text.
type UserSettings0 = interface{}
// UserSettings1 is an untyped (interface{}) alias for the second variant of the
// UserSettings union. The spec gives this variant no name either.
type UserSettings1 = interface{}
// UserUpdateRequest defines model for UserUpdateRequest.
//
// All fields are optional and omitted from JSON when nil; nil presumably means
// "leave unchanged". Note that "selectedRoles" is camelCase, unlike the other
// snake_case keys.
type UserUpdateRequest struct {
// AiEnabled Whether AI features are enabled for this user
AiEnabled *bool `json:"ai_enabled,omitempty"`
// AiModel AI model code
AiModel *string `json:"ai_model,omitempty"`
// AiProvider AI provider code
AiProvider *string `json:"ai_provider,omitempty"`
// ApiKey API key for AI provider (write-only)
ApiKey *string `json:"api_key,omitempty"`
// CurrentLevel Current proficiency level
CurrentLevel *string `json:"current_level,omitempty"`
// Email Email address
Email *openapi_types.Email `json:"email,omitempty"`
// PreferredLanguage Preferred learning language
PreferredLanguage *string `json:"preferred_language,omitempty"`
// SelectedRoles Array of role names to assign to the user
SelectedRoles *[]string `json:"selectedRoles,omitempty"`
// Timezone Timezone (e.g., "UTC", "America/New_York")
Timezone *string `json:"timezone,omitempty"`
// Username Username (1-100 characters, alphanumeric + underscore + email characters, cannot be empty or whitespace-only)
Username *string `json:"username,omitempty"`
// union is unexported, so encoding/json ignores it by default.
// NOTE(review): it is presumably handled by custom (Un)MarshalJSON methods
// defined elsewhere. Confirm.
union json.RawMessage
}
// UserUpdateRequest0 is an untyped (interface{}) alias for the first variant of
// the UserUpdateRequest union. The spec gives this variant no name, hence the
// generator's empty "defines model for ." text.
type UserUpdateRequest0 = interface{}
// UserUpdateRequest1 is an untyped (interface{}) alias for the second variant of
// the UserUpdateRequest union. The spec gives this variant no name either.
type UserUpdateRequest1 = interface{}
// WorkerHealth defines model for WorkerHealth.
//
// This is a health summary plus per-instance details. All fields are optional
// and omitted from JSON when nil.
type WorkerHealth struct {
GlobalPaused *bool `json:"global_paused,omitempty"`
HealthyCount *int `json:"healthy_count,omitempty"`
TotalCount *int `json:"total_count,omitempty"`
WorkerInstances *[]struct {
Healthy *bool `json:"healthy,omitempty"`
IsPaused *bool `json:"is_paused,omitempty"`
IsRunning *bool `json:"is_running,omitempty"`
// LastHeartbeat uses capitalized "Time"/"Valid" keys.
// NOTE(review): this looks like the default JSON shape of a Go
// sql.NullTime-style value serialized without custom tags. Confirm.
LastHeartbeat *struct {
Time *string `json:"Time,omitempty"`
Valid *bool `json:"Valid,omitempty"`
} `json:"last_heartbeat,omitempty"`
TotalQuestionsGenerated *int `json:"total_questions_generated,omitempty"`
TotalRuns *int `json:"total_runs,omitempty"`
WorkerInstance *string `json:"worker_instance,omitempty"`
} `json:"worker_instances,omitempty"`
}
// WorkerStatus defines model for WorkerStatus.
//
// ErrorMessage has no ",omitempty", so it is always emitted (JSON null when nil).
type WorkerStatus struct {
// ErrorMessage Error message if the worker is in an error state
ErrorMessage *string `json:"error_message"`
// LastHeartbeat Timestamp of the last heartbeat from the worker
LastHeartbeat *string `json:"last_heartbeat,omitempty"`
// Status Current status of the worker
Status *WorkerStatusStatus `json:"status,omitempty"`
}
// WorkerStatusStatus Current status of the worker.
// NOTE(review): the allowed values are presumably declared as constants elsewhere. Confirm.
type WorkerStatusStatus string
// WorkerStatusResponse defines model for WorkerStatusResponse.
//
// All fields are required and always emitted. When there is no error,
// ErrorMessage and LastErrorDetails encode as "" (empty string), not null.
type WorkerStatusResponse struct {
// ErrorMessage Error message if worker has errors
ErrorMessage string `json:"error_message"`
// GlobalPaused Whether the worker is globally paused
GlobalPaused bool `json:"global_paused"`
// HasErrors Whether the worker has encountered errors
HasErrors bool `json:"has_errors"`
// HealthyWorkers Number of healthy worker instances
HealthyWorkers int `json:"healthy_workers"`
// LastErrorDetails Detailed error information if any
LastErrorDetails string `json:"last_error_details"`
// TotalWorkers Total number of worker instances
TotalWorkers int `json:"total_workers"`
// UserPaused Whether the user's question generation is paused
UserPaused bool `json:"user_paused"`
// WorkerRunning Whether the worker is currently running
WorkerRunning bool `json:"worker_running"`
}
// GetV1AdminBackendQuestionsParams defines parameters for GetV1AdminBackendQuestions.
//
// These are query parameters; the `form` tags name the query keys. Every
// parameter is optional, and nil means "not supplied".
type GetV1AdminBackendQuestionsParams struct {
// Page Page number (1-based)
Page *int `form:"page,omitempty" json:"page,omitempty"`
// PageSize Number of questions per page
PageSize *int `form:"page_size,omitempty" json:"page_size,omitempty"`
// Search Search term for question content
Search *string `form:"search,omitempty" json:"search,omitempty"`
// Type Filter by question type
Type *QuestionType `form:"type,omitempty" json:"type,omitempty"`
// Status Filter by question status
Status *QuestionStatus `form:"status,omitempty" json:"status,omitempty"`
// Language Filter by language
Language *Language `form:"language,omitempty" json:"language,omitempty"`
// Level Filter by level
Level *Level `form:"level,omitempty" json:"level,omitempty"`
// UserId Filter by user ID (optional)
UserId *int64 `form:"user_id,omitempty" json:"user_id,omitempty"`
}
// GetV1AdminBackendQuestionsPaginatedParams defines parameters for GetV1AdminBackendQuestionsPaginated.
//
// These are query parameters with the same fields and tags as
// GetV1AdminBackendQuestionsParams. Every parameter is optional, and nil means
// "not supplied".
type GetV1AdminBackendQuestionsPaginatedParams struct {
// Page Page number (1-based)
Page *int `form:"page,omitempty" json:"page,omitempty"`
// PageSize Number of questions per page
PageSize *int `form:"page_size,omitempty" json:"page_size,omitempty"`
// Search Search term for question content
Search *string `form:"search,omitempty" json:"search,omitempty"`
// Type Filter by question type
Type *QuestionType `form:"type,omitempty" json:"type,omitempty"`
// Status Filter by question status
Status *QuestionStatus `form:"status,omitempty" json:"status,omitempty"`
// Language Filter by language
Language *Language `form:"language,omitempty" json:"language,omitempty"`
// Level Filter by level
Level *Level `form:"level,omitempty" json:"level,omitempty"`
// UserId Filter by user ID (optional)
UserId *int64 `form:"user_id,omitempty" json:"user_id,omitempty"`
}
// PutV1AdminBackendQuestionsIdJSONBody defines parameters for PutV1AdminBackendQuestionsId.
//
// Content is an untyped JSON object, not QuestionContent, so its shape is not
// checked by the decoder. Content and Explanation are always emitted; a nil
// Content map encodes as null.
type PutV1AdminBackendQuestionsIdJSONBody struct {
// Content Updated question content
Content map[string]interface{} `json:"content"`
// CorrectAnswer Index of the correct answer
CorrectAnswer *int `json:"correct_answer,omitempty"`
// Explanation Explanation for the correct answer
Explanation string `json:"explanation"`
}
// PostV1AdminBackendQuestionsIdAiFixJSONBody defines parameters for PostV1AdminBackendQuestionsIdAiFix.
type PostV1AdminBackendQuestionsIdAiFixJSONBody struct {
// AdditionalContext Optional extra text for the AI fix request; omitted when nil.
AdditionalContext *string `json:"additional_context,omitempty"`
}
// PostV1AdminBackendQuestionsIdAssignUsersJSONBody defines parameters for PostV1AdminBackendQuestionsIdAssignUsers.
type PostV1AdminBackendQuestionsIdAssignUsersJSONBody struct {
// UserIds Array of user IDs to assign to the question
UserIds []int64 `json:"user_ids"`
}
// PostV1AdminBackendQuestionsIdUnassignUsersJSONBody defines parameters for PostV1AdminBackendQuestionsIdUnassignUsers.
type PostV1AdminBackendQuestionsIdUnassignUsersJSONBody struct {
// UserIds Array of user IDs to unassign from the question
UserIds []int64 `json:"user_ids"`
}
// GetV1AdminBackendReportedQuestionsParams defines parameters for GetV1AdminBackendReportedQuestions.
//
// These are query parameters. Unlike the general question listing, there is no
// Status or UserId filter here. Every parameter is optional, and nil means "not
// supplied".
type GetV1AdminBackendReportedQuestionsParams struct {
// Page Page number (1-based)
Page *int `form:"page,omitempty" json:"page,omitempty"`
// PageSize Number of questions per page
PageSize *int `form:"page_size,omitempty" json:"page_size,omitempty"`
// Search Search term for question content
Search *string `form:"search,omitempty" json:"search,omitempty"`
// Type Filter by question type
Type *QuestionType `form:"type,omitempty" json:"type,omitempty"`
// Language Filter by language
Language *Language `form:"language,omitempty" json:"language,omitempty"`
// Level Filter by level
Level *Level `form:"level,omitempty" json:"level,omitempty"`
}
// PostV1AdminBackendUserzJSONBody defines parameters for PostV1AdminBackendUserz.
//
// This is the admin "create user" body. Email, Password and Username are
// required (non-pointer). Note that it uses "language"/"level" keys, whereas
// UserCreateRequest uses "preferred_language"/"current_level".
type PostV1AdminBackendUserzJSONBody struct {
// AiEnabled Whether AI is enabled for this user
AiEnabled *bool `json:"ai_enabled,omitempty"`
// AiModel AI model preference
AiModel *string `json:"ai_model,omitempty"`
// AiProvider AI provider preference
AiProvider *string `json:"ai_provider,omitempty"`
// Email Email address for the new user
Email openapi_types.Email `json:"email"`
// Language Preferred language for the user
Language *string `json:"language,omitempty"`
// Level Current level for the user
Level *string `json:"level,omitempty"`
// Password Password for the new user
Password string `json:"password"`
// Username Username (1-100 characters, alphanumeric + underscore + email characters, cannot be empty or whitespace-only)
Username string `json:"username"`
}
// GetV1AdminBackendUserzPaginatedParams defines parameters for GetV1AdminBackendUserzPaginated.
//
// These are query parameters. Every parameter is optional, and nil means "not
// supplied". AiEnabled and Active are string-typed enums, not bools.
type GetV1AdminBackendUserzPaginatedParams struct {
// Page Page number (1-based)
Page *int `form:"page,omitempty" json:"page,omitempty"`
// PageSize Number of users per page
PageSize *int `form:"page_size,omitempty" json:"page_size,omitempty"`
// Search Search term for username or email
Search *string `form:"search,omitempty" json:"search,omitempty"`
// Language Filter by preferred language
Language *Language `form:"language,omitempty" json:"language,omitempty"`
// Level Filter by current level
Level *Level `form:"level,omitempty" json:"level,omitempty"`
// AiProvider Filter by AI provider
AiProvider *string `form:"ai_provider,omitempty" json:"ai_provider,omitempty"`
// AiModel Filter by AI model
AiModel *string `form:"ai_model,omitempty" json:"ai_model,omitempty"`
// AiEnabled Filter by AI enabled status
AiEnabled *GetV1AdminBackendUserzPaginatedParamsAiEnabled `form:"ai_enabled,omitempty" json:"ai_enabled,omitempty"`
// Active Filter by active status (active within 7 days)
Active *GetV1AdminBackendUserzPaginatedParamsActive `form:"active,omitempty" json:"active,omitempty"`
}
// GetV1AdminBackendUserzPaginatedParamsAiEnabled defines parameters for GetV1AdminBackendUserzPaginated.
// NOTE(review): the allowed values are presumably declared as constants elsewhere. Confirm.
type GetV1AdminBackendUserzPaginatedParamsAiEnabled string
// GetV1AdminBackendUserzPaginatedParamsActive defines parameters for GetV1AdminBackendUserzPaginated.
// NOTE(review): the allowed values are presumably declared as constants elsewhere. Confirm.
type GetV1AdminBackendUserzPaginatedParamsActive string
// PostV1AdminBackendUserzIdRolesJSONBody defines parameters for PostV1AdminBackendUserzIdRoles.
//
// RoleId is required; a missing key decodes as 0.
type PostV1AdminBackendUserzIdRolesJSONBody struct {
// RoleId Role ID to assign
RoleId int64 `json:"role_id"`
}
// GetV1AdminWorkerNotificationsErrorsParams defines parameters for GetV1AdminWorkerNotificationsErrors.
//
// These are query parameters. Every parameter is optional, and nil means "not
// supplied". Resolved is a string-typed enum, not a bool.
type GetV1AdminWorkerNotificationsErrorsParams struct {
// Page Page number (1-based)
Page *int `form:"page,omitempty" json:"page,omitempty"`
// PageSize Number of errors per page
PageSize *int `form:"page_size,omitempty" json:"page_size,omitempty"`
// ErrorType Filter by error type
ErrorType *GetV1AdminWorkerNotificationsErrorsParamsErrorType `form:"error_type,omitempty" json:"error_type,omitempty"`
// NotificationType Filter by notification type
NotificationType *GetV1AdminWorkerNotificationsErrorsParamsNotificationType `form:"notification_type,omitempty" json:"notification_type,omitempty"`
// Resolved Filter by resolution status
Resolved *GetV1AdminWorkerNotificationsErrorsParamsResolved `form:"resolved,omitempty" json:"resolved,omitempty"`
}
// GetV1AdminWorkerNotificationsErrorsParamsErrorType defines parameters for GetV1AdminWorkerNotificationsErrors.
// NOTE(review): the allowed values are presumably declared as constants elsewhere. Confirm.
type GetV1AdminWorkerNotificationsErrorsParamsErrorType string
// GetV1AdminWorkerNotificationsErrorsParamsNotificationType defines parameters for GetV1AdminWorkerNotificationsErrors.
// NOTE(review): the allowed values are presumably declared as constants elsewhere. Confirm.
type GetV1AdminWorkerNotificationsErrorsParamsNotificationType string
// GetV1AdminWorkerNotificationsErrorsParamsResolved defines parameters for GetV1AdminWorkerNotificationsErrors.
// NOTE(review): the allowed values are presumably declared as constants elsewhere. Confirm.
type GetV1AdminWorkerNotificationsErrorsParamsResolved string
// PostV1AdminWorkerNotificationsForceSendJSONBody defines parameters for PostV1AdminWorkerNotificationsForceSend.
//
// The target user is identified by username, not by numeric ID.
type PostV1AdminWorkerNotificationsForceSendJSONBody struct {
// Username Username of the user to send notification to
Username string `json:"username"`
}
// GetV1AdminWorkerNotificationsSentParams defines parameters for GetV1AdminWorkerNotificationsSent.
//
// These are query parameters. Every parameter is optional, and nil means "not
// supplied". SentAfter and SentBefore are raw strings; this type does not
// validate or fix their timestamp format.
type GetV1AdminWorkerNotificationsSentParams struct {
// Page Page number (1-based)
Page *int `form:"page,omitempty" json:"page,omitempty"`
// PageSize Number of notifications per page
PageSize *int `form:"page_size,omitempty" json:"page_size,omitempty"`
// NotificationType Filter by notification type
NotificationType *GetV1AdminWorkerNotificationsSentParamsNotificationType `form:"notification_type,omitempty" json:"notification_type,omitempty"`
// Status Filter by status
Status *GetV1AdminWorkerNotificationsSentParamsStatus `form:"status,omitempty" json:"status,omitempty"`
// SentAfter Filter notifications sent after this timestamp
SentAfter *string `form:"sent_after,omitempty" json:"sent_after,omitempty"`
// SentBefore Filter notifications sent before this timestamp
SentBefore *string `form:"sent_before,omitempty" json:"sent_before,omitempty"`
}
// GetV1AdminWorkerNotificationsSentParamsNotificationType defines parameters for GetV1AdminWorkerNotificationsSent.
// NOTE(review): the allowed values are presumably declared as constants elsewhere. Confirm.
type GetV1AdminWorkerNotificationsSentParamsNotificationType string
// GetV1AdminWorkerNotificationsSentParamsStatus defines parameters for GetV1AdminWorkerNotificationsSent.
// NOTE(review): the allowed values are presumably declared as constants elsewhere. Confirm.
type GetV1AdminWorkerNotificationsSentParamsStatus string
// PostV1AdminWorkerUsersPauseJSONBody defines parameters for PostV1AdminWorkerUsersPause.
//
// NOTE(review): UserId is `int` here but `int64` in every other model in this
// file. Consider declaring it as int64 in the OpenAPI spec for consistency.
// That would be a Go type change for callers, so it needs a coordinated
// regeneration.
type PostV1AdminWorkerUsersPauseJSONBody struct {
// UserId ID of the user to pause
UserId int `json:"user_id"`
}
// PostV1AdminWorkerUsersResumeJSONBody defines parameters for PostV1AdminWorkerUsersResume.
//
// NOTE(review): UserId is `int` here but `int64` elsewhere (same concern as
// PostV1AdminWorkerUsersPauseJSONBody).
type PostV1AdminWorkerUsersResumeJSONBody struct {
// UserId ID of the user to resume
UserId int `json:"user_id"`
}
// GetV1AuthGoogleCallbackParams defines parameters for GetV1AuthGoogleCallback.
//
// These are the query parameters of the OAuth callback. Code is required.
// State is optional at the type level, but the handler should still verify it
// for CSRF protection.
type GetV1AuthGoogleCallbackParams struct {
// Code Authorization code from Google
Code string `form:"code" json:"code"`
// State State parameter for CSRF protection
State *string `form:"state,omitempty" json:"state,omitempty"`
}
// PostV1DailyQuestionsDateAnswerQuestionIdJSONBody defines parameters for PostV1DailyQuestionsDateAnswerQuestionId.
//
// UserAnswerIndex is a plain int, so a missing key decodes as 0, which is
// indistinguishable from choosing the first option. Range checking against the
// question's options is not done here.
type PostV1DailyQuestionsDateAnswerQuestionIdJSONBody struct {
// UserAnswerIndex Index of the user's selected answer (0-based)
UserAnswerIndex int `json:"user_answer_index"`
}
// GetV1QuizQuestionParams defines parameters for GetV1QuizQuestion.
//
// These are query parameters. Type and ExcludeType are raw comma-separated
// strings, not QuestionType values, so callers must split and validate them.
type GetV1QuizQuestionParams struct {
// Language Preferred language for the question
Language *Language `form:"language,omitempty" json:"language,omitempty"`
// Level Difficulty level for the question
Level *Level `form:"level,omitempty" json:"level,omitempty"`
// Type Specific question type(s) to retrieve (comma-separated list). If multiple types are provided, the first valid type will be used.
Type *string `form:"type,omitempty" json:"type,omitempty"`
// ExcludeType Question type(s) to exclude from random selection (comma-separated list). Useful for filtering out specific question types from the general quiz.
ExcludeType *string `form:"exclude_type,omitempty" json:"exclude_type,omitempty"`
}
// PostV1QuizQuestionIdMarkKnownJSONBody defines parameters for PostV1QuizQuestionIdMarkKnown.
type PostV1QuizQuestionIdMarkKnownJSONBody struct {
// ConfidenceLevel User's confidence level (1-5, optional)
ConfidenceLevel *int `json:"confidence_level,omitempty"`
}
// PostV1QuizQuestionIdReportJSONBody defines parameters for PostV1QuizQuestionIdReport.
type PostV1QuizQuestionIdReportJSONBody struct {
// ReportReason Optional explanation for why the question is being reported
ReportReason *string `json:"report_reason,omitempty"`
}
// GetV1SettingsLevelsParams defines parameters for GetV1SettingsLevels.
type GetV1SettingsLevelsParams struct {
// Language Language to get levels for (optional - returns all levels if not specified)
Language *string `form:"language,omitempty" json:"language,omitempty"`
}
// Request body types.
//
// Declarations written with `=` (for example PutV1AdminBackendUserzIdJSONRequestBody = UserUpdateRequest)
// are type aliases: they are the same type as the right-hand side and share its method set, so the custom
// MarshalJSON/UnmarshalJSON of UserSettings and UserUpdateRequest apply to their aliases.
// Declarations without `=` define new named types whose underlying type is the corresponding *JSONBody
// struct: values must be converted explicitly (e.g. PostV1AdminWorkerUsersPauseJSONRequestBody(body)),
// and they do not inherit methods declared on the *JSONBody type.
//
// PutV1AdminBackendQuestionsIdJSONRequestBody defines body for PutV1AdminBackendQuestionsId for application/json ContentType.
type PutV1AdminBackendQuestionsIdJSONRequestBody PutV1AdminBackendQuestionsIdJSONBody
// PostV1AdminBackendQuestionsIdAiFixJSONRequestBody defines body for PostV1AdminBackendQuestionsIdAiFix for application/json ContentType.
type PostV1AdminBackendQuestionsIdAiFixJSONRequestBody PostV1AdminBackendQuestionsIdAiFixJSONBody
// PostV1AdminBackendQuestionsIdAssignUsersJSONRequestBody defines body for PostV1AdminBackendQuestionsIdAssignUsers for application/json ContentType.
type PostV1AdminBackendQuestionsIdAssignUsersJSONRequestBody PostV1AdminBackendQuestionsIdAssignUsersJSONBody
// PostV1AdminBackendQuestionsIdUnassignUsersJSONRequestBody defines body for PostV1AdminBackendQuestionsIdUnassignUsers for application/json ContentType.
type PostV1AdminBackendQuestionsIdUnassignUsersJSONRequestBody PostV1AdminBackendQuestionsIdUnassignUsersJSONBody
// PostV1AdminBackendUserzJSONRequestBody defines body for PostV1AdminBackendUserz for application/json ContentType.
type PostV1AdminBackendUserzJSONRequestBody PostV1AdminBackendUserzJSONBody
// PutV1AdminBackendUserzIdJSONRequestBody defines body for PutV1AdminBackendUserzId for application/json ContentType.
type PutV1AdminBackendUserzIdJSONRequestBody = UserUpdateRequest
// PostV1AdminBackendUserzIdResetPasswordJSONRequestBody defines body for PostV1AdminBackendUserzIdResetPassword for application/json ContentType.
type PostV1AdminBackendUserzIdResetPasswordJSONRequestBody = PasswordResetRequest
// PostV1AdminBackendUserzIdRolesJSONRequestBody defines body for PostV1AdminBackendUserzIdRoles for application/json ContentType.
type PostV1AdminBackendUserzIdRolesJSONRequestBody PostV1AdminBackendUserzIdRolesJSONBody
// PostV1AdminWorkerNotificationsForceSendJSONRequestBody defines body for PostV1AdminWorkerNotificationsForceSend for application/json ContentType.
type PostV1AdminWorkerNotificationsForceSendJSONRequestBody PostV1AdminWorkerNotificationsForceSendJSONBody
// PostV1AdminWorkerUsersPauseJSONRequestBody defines body for PostV1AdminWorkerUsersPause for application/json ContentType.
type PostV1AdminWorkerUsersPauseJSONRequestBody PostV1AdminWorkerUsersPauseJSONBody
// PostV1AdminWorkerUsersResumeJSONRequestBody defines body for PostV1AdminWorkerUsersResume for application/json ContentType.
type PostV1AdminWorkerUsersResumeJSONRequestBody PostV1AdminWorkerUsersResumeJSONBody
// PostV1AudioSpeechJSONRequestBody defines body for PostV1AudioSpeech for application/json ContentType.
type PostV1AudioSpeechJSONRequestBody = TTSRequest
// PostV1AuthLoginJSONRequestBody defines body for PostV1AuthLogin for application/json ContentType.
type PostV1AuthLoginJSONRequestBody = LoginRequest
// PostV1AuthSignupJSONRequestBody defines body for PostV1AuthSignup for application/json ContentType.
type PostV1AuthSignupJSONRequestBody = UserCreateRequest
// PostV1DailyQuestionsDateAnswerQuestionIdJSONRequestBody defines body for PostV1DailyQuestionsDateAnswerQuestionId for application/json ContentType.
type PostV1DailyQuestionsDateAnswerQuestionIdJSONRequestBody PostV1DailyQuestionsDateAnswerQuestionIdJSONBody
// PutV1PreferencesLearningJSONRequestBody defines body for PutV1PreferencesLearning for application/json ContentType.
type PutV1PreferencesLearningJSONRequestBody = UserLearningPreferences
// PostV1QuizAnswerJSONRequestBody defines body for PostV1QuizAnswer for application/json ContentType.
type PostV1QuizAnswerJSONRequestBody = AnswerRequest
// PostV1QuizChatStreamJSONRequestBody defines body for PostV1QuizChatStream for application/json ContentType.
type PostV1QuizChatStreamJSONRequestBody = QuizChatRequest
// PostV1QuizQuestionIdMarkKnownJSONRequestBody defines body for PostV1QuizQuestionIdMarkKnown for application/json ContentType.
type PostV1QuizQuestionIdMarkKnownJSONRequestBody PostV1QuizQuestionIdMarkKnownJSONBody
// PostV1QuizQuestionIdReportJSONRequestBody defines body for PostV1QuizQuestionIdReport for application/json ContentType.
type PostV1QuizQuestionIdReportJSONRequestBody PostV1QuizQuestionIdReportJSONBody
// PutV1SettingsJSONRequestBody defines body for PutV1Settings for application/json ContentType.
type PutV1SettingsJSONRequestBody = UserSettings
// PostV1SettingsTestAiJSONRequestBody defines body for PostV1SettingsTestAi for application/json ContentType.
type PostV1SettingsTestAiJSONRequestBody = TestAIRequest
// PutV1UserzProfileJSONRequestBody defines body for PutV1UserzProfile for application/json ContentType.
type PutV1UserzProfileJSONRequestBody = UserUpdateRequest
// AsServiceVersion decodes the raw union payload of the AggregatedVersion_Worker
// as a ServiceVersion. On a decode error the returned value may be partially
// populated and should be ignored.
func (t AggregatedVersion_Worker) AsServiceVersion() (ServiceVersion, error) {
	var body ServiceVersion
	err := json.Unmarshal(t.union, &body)
	return body, err
}

// FromServiceVersion replaces the union payload of the AggregatedVersion_Worker
// with the JSON encoding of v. If encoding fails, the existing payload is left
// untouched and the error is returned.
//
// NOTE(review): this looks like oapi-codegen output; if so, hand edits are lost
// on regeneration and the fix should also go into the generator templates.
func (t *AggregatedVersion_Worker) FromServiceVersion(v ServiceVersion) error {
	b, err := json.Marshal(v)
	if err != nil {
		return err
	}
	t.union = b
	return nil
}

// MergeServiceVersion merges the JSON encoding of v into the current union
// payload via runtime.JSONMerge. On any error (encoding or merging) the
// existing payload is left untouched and the error is returned.
func (t *AggregatedVersion_Worker) MergeServiceVersion(v ServiceVersion) error {
	b, err := json.Marshal(v)
	if err != nil {
		return err
	}
	merged, err := runtime.JSONMerge(t.union, b)
	if err != nil {
		return err
	}
	t.union = merged
	return nil
}
// AsAggregatedVersionWorker1 decodes the raw union payload of the
// AggregatedVersion_Worker as an AggregatedVersionWorker1. On a decode error the
// returned value may be partially populated and should be ignored.
func (t AggregatedVersion_Worker) AsAggregatedVersionWorker1() (AggregatedVersionWorker1, error) {
	var body AggregatedVersionWorker1
	err := json.Unmarshal(t.union, &body)
	return body, err
}

// FromAggregatedVersionWorker1 replaces the union payload of the
// AggregatedVersion_Worker with the JSON encoding of v. If encoding fails, the
// existing payload is left untouched and the error is returned.
//
// NOTE(review): this looks like oapi-codegen output; if so, hand edits are lost
// on regeneration and the fix should also go into the generator templates.
func (t *AggregatedVersion_Worker) FromAggregatedVersionWorker1(v AggregatedVersionWorker1) error {
	b, err := json.Marshal(v)
	if err != nil {
		return err
	}
	t.union = b
	return nil
}

// MergeAggregatedVersionWorker1 merges the JSON encoding of v into the current
// union payload via runtime.JSONMerge. On any error the existing payload is left
// untouched and the error is returned.
func (t *AggregatedVersion_Worker) MergeAggregatedVersionWorker1(v AggregatedVersionWorker1) error {
	b, err := json.Marshal(v)
	if err != nil {
		return err
	}
	merged, err := runtime.JSONMerge(t.union, b)
	if err != nil {
		return err
	}
	t.union = merged
	return nil
}
// MarshalJSON emits the union payload of the AggregatedVersion_Worker as-is,
// delegating to the payload's own MarshalJSON (the payload appears to be a
// json.RawMessage — confirm against the type declaration).
func (t AggregatedVersion_Worker) MarshalJSON() ([]byte, error) {
	return t.union.MarshalJSON()
}

// UnmarshalJSON stores the incoming JSON as the union payload, delegating to
// the payload's own UnmarshalJSON. Use the As* accessors to decode a variant.
func (t *AggregatedVersion_Worker) UnmarshalJSON(b []byte) error {
	return t.union.UnmarshalJSON(b)
}
// AsUserSettings0 decodes the raw union payload of the UserSettings as a
// UserSettings0. On a decode error the returned value may be partially
// populated and should be ignored.
func (t UserSettings) AsUserSettings0() (UserSettings0, error) {
	var body UserSettings0
	err := json.Unmarshal(t.union, &body)
	return body, err
}

// FromUserSettings0 replaces the union payload of the UserSettings with the JSON
// encoding of v. The named fields of UserSettings are not touched. If encoding
// fails, the existing payload is left untouched and the error is returned.
//
// NOTE(review): this looks like oapi-codegen output; if so, hand edits are lost
// on regeneration and the fix should also go into the generator templates.
func (t *UserSettings) FromUserSettings0(v UserSettings0) error {
	b, err := json.Marshal(v)
	if err != nil {
		return err
	}
	t.union = b
	return nil
}

// MergeUserSettings0 merges the JSON encoding of v into the current union payload
// via runtime.JSONMerge. On any error the existing payload is left untouched and
// the error is returned.
func (t *UserSettings) MergeUserSettings0(v UserSettings0) error {
	b, err := json.Marshal(v)
	if err != nil {
		return err
	}
	merged, err := runtime.JSONMerge(t.union, b)
	if err != nil {
		return err
	}
	t.union = merged
	return nil
}
// AsUserSettings1 decodes the raw union payload of the UserSettings as a
// UserSettings1. On a decode error the returned value may be partially
// populated and should be ignored.
func (t UserSettings) AsUserSettings1() (UserSettings1, error) {
	var body UserSettings1
	err := json.Unmarshal(t.union, &body)
	return body, err
}

// FromUserSettings1 replaces the union payload of the UserSettings with the JSON
// encoding of v. The named fields of UserSettings are not touched. If encoding
// fails, the existing payload is left untouched and the error is returned.
//
// NOTE(review): this looks like oapi-codegen output; if so, hand edits are lost
// on regeneration and the fix should also go into the generator templates.
func (t *UserSettings) FromUserSettings1(v UserSettings1) error {
	b, err := json.Marshal(v)
	if err != nil {
		return err
	}
	t.union = b
	return nil
}

// MergeUserSettings1 merges the JSON encoding of v into the current union payload
// via runtime.JSONMerge. On any error the existing payload is left untouched and
// the error is returned.
func (t *UserSettings) MergeUserSettings1(v UserSettings1) error {
	b, err := json.Marshal(v)
	if err != nil {
		return err
	}
	merged, err := runtime.JSONMerge(t.union, b)
	if err != nil {
		return err
	}
	t.union = merged
	return nil
}
// MarshalJSON encodes UserSettings as one flat JSON object.
//
// It starts from the raw union payload: when t.union is non-nil it is decoded
// into a key/value map, so the payload must be a JSON object (otherwise the
// decode error is returned). Each named field that is non-nil (ai_enabled,
// ai_model, ai_provider, api_key, language, level) is then encoded and stored
// under its JSON key, overwriting a same-named key from the union payload.
// Because the result is a Go map passed to json.Marshal, keys are emitted in
// sorted order.
//
// NOTE(review): api_key is written out whenever ApiKey is set; confirm values of
// this type are never sent to clients with a populated key.
func (t UserSettings) MarshalJSON() ([]byte, error) {
	b, err := t.union.MarshalJSON()
	if err != nil {
		return nil, err
	}
	object := make(map[string]json.RawMessage)
	// A nil union contributes no keys; only the named fields below are emitted.
	if t.union != nil {
		err = json.Unmarshal(b, &object)
		if err != nil {
			return nil, err
		}
	}
	// Named fields override any same-named key that came from the union payload.
	if t.AiEnabled != nil {
		object["ai_enabled"], err = json.Marshal(t.AiEnabled)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'ai_enabled': %w", err)
		}
	}
	if t.AiModel != nil {
		object["ai_model"], err = json.Marshal(t.AiModel)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'ai_model': %w", err)
		}
	}
	if t.AiProvider != nil {
		object["ai_provider"], err = json.Marshal(t.AiProvider)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'ai_provider': %w", err)
		}
	}
	if t.ApiKey != nil {
		object["api_key"], err = json.Marshal(t.ApiKey)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'api_key': %w", err)
		}
	}
	if t.Language != nil {
		object["language"], err = json.Marshal(t.Language)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'language': %w", err)
		}
	}
	if t.Level != nil {
		object["level"], err = json.Marshal(t.Level)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'level': %w", err)
		}
	}
	b, err = json.Marshal(object)
	return b, err
}
// UnmarshalJSON decodes a flat JSON object into UserSettings.
//
// The complete input is first stored as the raw union payload, so the variant
// accessors (AsUserSettings0 / AsUserSettings1) can decode it later. The input is
// then decoded into a key/value map (it must be a JSON object) and each known key
// (ai_enabled, ai_model, ai_provider, api_key, language, level) is decoded into
// the matching named field. Keys absent from the input leave the field as it was.
//
// The union payload is overwritten before the object decode is attempted, so on
// an error t may be partially updated. The trailing `return err` always returns
// nil: every failure returns earlier.
func (t *UserSettings) UnmarshalJSON(b []byte) error {
	err := t.union.UnmarshalJSON(b)
	if err != nil {
		return err
	}
	object := make(map[string]json.RawMessage)
	err = json.Unmarshal(b, &object)
	if err != nil {
		return err
	}
	if raw, found := object["ai_enabled"]; found {
		err = json.Unmarshal(raw, &t.AiEnabled)
		if err != nil {
			return fmt.Errorf("error reading 'ai_enabled': %w", err)
		}
	}
	if raw, found := object["ai_model"]; found {
		err = json.Unmarshal(raw, &t.AiModel)
		if err != nil {
			return fmt.Errorf("error reading 'ai_model': %w", err)
		}
	}
	if raw, found := object["ai_provider"]; found {
		err = json.Unmarshal(raw, &t.AiProvider)
		if err != nil {
			return fmt.Errorf("error reading 'ai_provider': %w", err)
		}
	}
	if raw, found := object["api_key"]; found {
		err = json.Unmarshal(raw, &t.ApiKey)
		if err != nil {
			return fmt.Errorf("error reading 'api_key': %w", err)
		}
	}
	if raw, found := object["language"]; found {
		err = json.Unmarshal(raw, &t.Language)
		if err != nil {
			return fmt.Errorf("error reading 'language': %w", err)
		}
	}
	if raw, found := object["level"]; found {
		err = json.Unmarshal(raw, &t.Level)
		if err != nil {
			return fmt.Errorf("error reading 'level': %w", err)
		}
	}
	return err
}
// AsUserUpdateRequest0 decodes the raw union payload of the UserUpdateRequest as
// a UserUpdateRequest0. On a decode error the returned value may be partially
// populated and should be ignored.
func (t UserUpdateRequest) AsUserUpdateRequest0() (UserUpdateRequest0, error) {
	var body UserUpdateRequest0
	err := json.Unmarshal(t.union, &body)
	return body, err
}

// FromUserUpdateRequest0 replaces the union payload of the UserUpdateRequest with
// the JSON encoding of v. The named fields of UserUpdateRequest are not touched.
// If encoding fails, the existing payload is left untouched and the error is returned.
//
// NOTE(review): this looks like oapi-codegen output; if so, hand edits are lost
// on regeneration and the fix should also go into the generator templates.
func (t *UserUpdateRequest) FromUserUpdateRequest0(v UserUpdateRequest0) error {
	b, err := json.Marshal(v)
	if err != nil {
		return err
	}
	t.union = b
	return nil
}

// MergeUserUpdateRequest0 merges the JSON encoding of v into the current union
// payload via runtime.JSONMerge. On any error the existing payload is left
// untouched and the error is returned.
func (t *UserUpdateRequest) MergeUserUpdateRequest0(v UserUpdateRequest0) error {
	b, err := json.Marshal(v)
	if err != nil {
		return err
	}
	merged, err := runtime.JSONMerge(t.union, b)
	if err != nil {
		return err
	}
	t.union = merged
	return nil
}
// AsUserUpdateRequest1 decodes the raw union payload of the UserUpdateRequest as
// a UserUpdateRequest1. On a decode error the returned value may be partially
// populated and should be ignored.
func (t UserUpdateRequest) AsUserUpdateRequest1() (UserUpdateRequest1, error) {
	var body UserUpdateRequest1
	err := json.Unmarshal(t.union, &body)
	return body, err
}

// FromUserUpdateRequest1 replaces the union payload of the UserUpdateRequest with
// the JSON encoding of v. The named fields of UserUpdateRequest are not touched.
// If encoding fails, the existing payload is left untouched and the error is returned.
//
// NOTE(review): this looks like oapi-codegen output; if so, hand edits are lost
// on regeneration and the fix should also go into the generator templates.
func (t *UserUpdateRequest) FromUserUpdateRequest1(v UserUpdateRequest1) error {
	b, err := json.Marshal(v)
	if err != nil {
		return err
	}
	t.union = b
	return nil
}

// MergeUserUpdateRequest1 merges the JSON encoding of v into the current union
// payload via runtime.JSONMerge. On any error the existing payload is left
// untouched and the error is returned.
func (t *UserUpdateRequest) MergeUserUpdateRequest1(v UserUpdateRequest1) error {
	b, err := json.Marshal(v)
	if err != nil {
		return err
	}
	merged, err := runtime.JSONMerge(t.union, b)
	if err != nil {
		return err
	}
	t.union = merged
	return nil
}
// MarshalJSON encodes UserUpdateRequest as one flat JSON object.
//
// It starts from the raw union payload: when t.union is non-nil it is decoded
// into a key/value map, so the payload must be a JSON object (otherwise the
// decode error is returned). Each named field that is non-nil (ai_enabled,
// ai_model, ai_provider, api_key, current_level, email, preferred_language,
// selectedRoles, timezone, username) is then encoded and stored under its JSON
// key, overwriting a same-named key from the union payload. Because the result
// is a Go map passed to json.Marshal, keys are emitted in sorted order.
//
// NOTE(review): "selectedRoles" is camelCase while every other key is
// snake_case; presumably this mirrors the API spec — confirm before relying on it.
func (t UserUpdateRequest) MarshalJSON() ([]byte, error) {
	b, err := t.union.MarshalJSON()
	if err != nil {
		return nil, err
	}
	object := make(map[string]json.RawMessage)
	// A nil union contributes no keys; only the named fields below are emitted.
	if t.union != nil {
		err = json.Unmarshal(b, &object)
		if err != nil {
			return nil, err
		}
	}
	// Named fields override any same-named key that came from the union payload.
	if t.AiEnabled != nil {
		object["ai_enabled"], err = json.Marshal(t.AiEnabled)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'ai_enabled': %w", err)
		}
	}
	if t.AiModel != nil {
		object["ai_model"], err = json.Marshal(t.AiModel)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'ai_model': %w", err)
		}
	}
	if t.AiProvider != nil {
		object["ai_provider"], err = json.Marshal(t.AiProvider)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'ai_provider': %w", err)
		}
	}
	if t.ApiKey != nil {
		object["api_key"], err = json.Marshal(t.ApiKey)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'api_key': %w", err)
		}
	}
	if t.CurrentLevel != nil {
		object["current_level"], err = json.Marshal(t.CurrentLevel)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'current_level': %w", err)
		}
	}
	if t.Email != nil {
		object["email"], err = json.Marshal(t.Email)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'email': %w", err)
		}
	}
	if t.PreferredLanguage != nil {
		object["preferred_language"], err = json.Marshal(t.PreferredLanguage)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'preferred_language': %w", err)
		}
	}
	if t.SelectedRoles != nil {
		object["selectedRoles"], err = json.Marshal(t.SelectedRoles)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'selectedRoles': %w", err)
		}
	}
	if t.Timezone != nil {
		object["timezone"], err = json.Marshal(t.Timezone)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'timezone': %w", err)
		}
	}
	if t.Username != nil {
		object["username"], err = json.Marshal(t.Username)
		if err != nil {
			return nil, fmt.Errorf("error marshaling 'username': %w", err)
		}
	}
	b, err = json.Marshal(object)
	return b, err
}
// UnmarshalJSON decodes a flat JSON object into UserUpdateRequest.
//
// The complete input is first stored as the raw union payload, so the variant
// accessors (AsUserUpdateRequest0 / AsUserUpdateRequest1) can decode it later.
// The input is then decoded into a key/value map (it must be a JSON object) and
// each known key (ai_enabled, ai_model, ai_provider, api_key, current_level,
// email, preferred_language, selectedRoles, timezone, username) is decoded into
// the matching named field. Keys absent from the input leave the field as it was.
//
// The union payload is overwritten before the object decode is attempted, so on
// an error t may be partially updated. The trailing `return err` always returns
// nil: every failure returns earlier.
func (t *UserUpdateRequest) UnmarshalJSON(b []byte) error {
	err := t.union.UnmarshalJSON(b)
	if err != nil {
		return err
	}
	object := make(map[string]json.RawMessage)
	err = json.Unmarshal(b, &object)
	if err != nil {
		return err
	}
	if raw, found := object["ai_enabled"]; found {
		err = json.Unmarshal(raw, &t.AiEnabled)
		if err != nil {
			return fmt.Errorf("error reading 'ai_enabled': %w", err)
		}
	}
	if raw, found := object["ai_model"]; found {
		err = json.Unmarshal(raw, &t.AiModel)
		if err != nil {
			return fmt.Errorf("error reading 'ai_model': %w", err)
		}
	}
	if raw, found := object["ai_provider"]; found {
		err = json.Unmarshal(raw, &t.AiProvider)
		if err != nil {
			return fmt.Errorf("error reading 'ai_provider': %w", err)
		}
	}
	if raw, found := object["api_key"]; found {
		err = json.Unmarshal(raw, &t.ApiKey)
		if err != nil {
			return fmt.Errorf("error reading 'api_key': %w", err)
		}
	}
	if raw, found := object["current_level"]; found {
		err = json.Unmarshal(raw, &t.CurrentLevel)
		if err != nil {
			return fmt.Errorf("error reading 'current_level': %w", err)
		}
	}
	if raw, found := object["email"]; found {
		err = json.Unmarshal(raw, &t.Email)
		if err != nil {
			return fmt.Errorf("error reading 'email': %w", err)
		}
	}
	if raw, found := object["preferred_language"]; found {
		err = json.Unmarshal(raw, &t.PreferredLanguage)
		if err != nil {
			return fmt.Errorf("error reading 'preferred_language': %w", err)
		}
	}
	if raw, found := object["selectedRoles"]; found {
		err = json.Unmarshal(raw, &t.SelectedRoles)
		if err != nil {
			return fmt.Errorf("error reading 'selectedRoles': %w", err)
		}
	}
	if raw, found := object["timezone"]; found {
		err = json.Unmarshal(raw, &t.Timezone)
		if err != nil {
			return fmt.Errorf("error reading 'timezone': %w", err)
		}
	}
	if raw, found := object["username"]; found {
		err = json.Unmarshal(raw, &t.Username)
		if err != nil {
			return fmt.Errorf("error reading 'username': %w", err)
		}
	}
	return err
}
// Package config handles application configuration loading from environment variables.
package config
import (
"os"
"reflect"
"sort"
"strconv"
"strings"
"time"
contextutils "quizapp/internal/utils"
"gopkg.in/yaml.v3"
)
// ProviderConfig defines the structure for a single provider: one entry of the
// top-level `providers` list in the YAML config (Config.Providers), together
// with the models it offers. Being a slice element, it cannot be overridden
// from environment variables; only the YAML file sets it.
type ProviderConfig struct {
	Name string `json:"name" yaml:"name"`
	Code string `json:"code" yaml:"code"`
	URL string `json:"url,omitempty" yaml:"url,omitempty"`
	SupportsGrammar bool `json:"supports_grammar,omitempty" yaml:"supports_grammar,omitempty"`
	// QuestionBatchSize is 0 when not set in YAML; how 0 is interpreted is up to callers (not visible here).
	QuestionBatchSize int `json:"question_batch_size,omitempty" yaml:"question_batch_size,omitempty"`
	Models []AIModel `json:"models" yaml:"models"`
}
// AIModel represents an AI model configuration: one entry of ProviderConfig.Models.
type AIModel struct {
	Name string `json:"name" yaml:"name"`
	Code string `json:"code" yaml:"code"`
	// MaxTokens is 0 when not set in YAML; how 0 is interpreted is up to callers (not visible here).
	MaxTokens int `json:"max_tokens,omitempty" yaml:"max_tokens,omitempty"`
}
// QuestionVarietyConfig defines the variety configuration for question generation.
// It is the optional `variety` section (Config.Variety), which stays nil when absent from the YAML.
type QuestionVarietyConfig struct {
	TopicCategories []string `json:"topic_categories" yaml:"topic_categories"`
	// GrammarFocusByLevel is keyed by level name — presumably the same names as LanguageLevelConfig.Levels; confirm.
	GrammarFocusByLevel map[string][]string `json:"grammar_focus_by_level" yaml:"grammar_focus_by_level"`
	GrammarFocus []string `json:"grammar_focus" yaml:"grammar_focus"`
	VocabularyDomains []string `json:"vocabulary_domains" yaml:"vocabulary_domains"`
	Scenarios []string `json:"scenarios" yaml:"scenarios"`
	StyleModifiers []string `json:"style_modifiers" yaml:"style_modifiers"`
	DifficultyModifiers []string `json:"difficulty_modifiers" yaml:"difficulty_modifiers"`
	TimeContexts []string `json:"time_contexts" yaml:"time_contexts"`
}
// LanguageLevelConfig represents the levels and descriptions for a specific language.
// It is the value type of Config.LanguageLevels, whose map key is the language name (see GetLanguages).
// Levels is returned in its YAML order by GetLevelsForLanguage; Descriptions is keyed by level name.
type LanguageLevelConfig struct {
	Levels []string `json:"levels" yaml:"levels"`
	Descriptions map[string]string `json:"descriptions" yaml:"descriptions"`
}
// AuthConfig represents authentication-related configuration.
// It is read by IsSignupDisabled, IsEmailAllowed, IsDomainAllowed and IsOAuthSignupAllowed;
// allow-list entries are compared after trimming whitespace and lower-casing.
type AuthConfig struct {
	SignupsDisabled bool `json:"signups_disabled" yaml:"signups_disabled"`
	AllowedDomains []string `json:"allowed_domains,omitempty" yaml:"allowed_domains,omitempty"`
	AllowedEmails []string `json:"allowed_emails,omitempty" yaml:"allowed_emails,omitempty"`
}
// SystemConfig represents system-wide configuration: the optional `system`
// section (Config.System), which stays nil when absent from the YAML.
type SystemConfig struct {
	Auth AuthConfig `json:"auth" yaml:"auth"`
}
// Config holds all configuration for the application.
//
// NewConfig decodes it from YAML (loadConfigFromFile) and then overlays
// environment variables (overrideFromEnv). Variable names are derived from the
// yaml tags, e.g. Server.Port is read from SERVER_PORT. The environment overlay
// handles scalar fields, string slices and nested structs only, so slices of
// structs (Providers) and maps (LanguageLevels) can only be set from YAML.
//
// NOTE(review): this struct carries secrets (GoogleOAuthClientSecret, and the
// nested admin password, session secret and SMTP password) and has json tags;
// avoid serializing it to clients or logs.
type Config struct {
	// Server configuration
	Server ServerConfig `json:"server" yaml:"server"`
	// Database configuration
	Database DatabaseConfig `json:"database" yaml:"database"`
	// AI Providers and Language Levels
	Providers []ProviderConfig `json:"providers" yaml:"providers"`
	LanguageLevels map[string]LanguageLevelConfig `json:"language_levels" yaml:"language_levels"`
	// Variety and System are pointers and stay nil when their sections are absent from the YAML.
	Variety *QuestionVarietyConfig `json:"variety,omitempty" yaml:"variety,omitempty"`
	System *SystemConfig `json:"system,omitempty" yaml:"system,omitempty"`
	// OAuth Configuration
	GoogleOAuthClientID string `json:"google_oauth_client_id" yaml:"google_oauth_client_id"`
	GoogleOAuthClientSecret string `json:"google_oauth_client_secret" yaml:"google_oauth_client_secret"`
	GoogleOAuthRedirectURL string `json:"google_oauth_redirect_url" yaml:"google_oauth_redirect_url"`
	// OpenTelemetry Configuration
	OpenTelemetry OpenTelemetryConfig `json:"open_telemetry" yaml:"open_telemetry"`
	// Email Configuration
	Email EmailConfig `json:"email" yaml:"email"`
	// Internal fields
	// NOTE(review): IsTest is labelled internal but carries yaml/json tags, so it can
	// be set from the YAML file or the IS_TEST environment variable; confirm that is intended.
	IsTest bool `json:"is_test" yaml:"is_test"`
}
// ServerConfig represents server configuration: the `server` section of the
// YAML file (Config.Server). Each field can also be overridden by the
// environment variable SERVER_<UPPER-CASED YAML KEY>, e.g. SERVER_PORT, or
// SERVER_CORS_ORIGINS as a comma-separated list.
//
// NOTE(review): AdminPassword and SessionSecret are secrets with json tags;
// avoid serializing this struct to clients or logs.
type ServerConfig struct {
	Port string `json:"port" yaml:"port"`
	WorkerPort string `json:"worker_port" yaml:"worker_port"`
	AdminUsername string `json:"admin_username" yaml:"admin_username"`
	AdminPassword string `json:"admin_password" yaml:"admin_password"`
	SessionSecret string `json:"session_secret" yaml:"session_secret"`
	Debug bool `json:"debug" yaml:"debug"`
	LogLevel string `json:"log_level" yaml:"log_level"`
	WorkerBaseURL string `json:"worker_base_url" yaml:"worker_base_url"`
	WorkerInternalURL string `json:"worker_internal_url" yaml:"worker_internal_url"`
	BackendBaseURL string `json:"backend_base_url" yaml:"backend_base_url"`
	AppBaseURL string `json:"app_base_url" yaml:"app_base_url"`
	MaxAIConcurrent int `json:"max_ai_concurrent" yaml:"max_ai_concurrent"`
	MaxAIPerUser int `json:"max_ai_per_user" yaml:"max_ai_per_user"`
	CORSOrigins []string `json:"cors_origins" yaml:"cors_origins"`
	QuestionRefillThreshold int `json:"question_refill_threshold" yaml:"question_refill_threshold"`
	// DailyFreshQuestionRatio controls the minimum fraction of fresh (never-seen)
	// questions to aim for when refilling question pools (0.0 - 1.0). Example: 0.35
	// means at least 35% fresh questions when refilling.
	// The range is not validated in this package.
	DailyFreshQuestionRatio float64 `json:"daily_fresh_question_ratio" yaml:"daily_fresh_question_ratio"`
	MaxHistory int `json:"max_history" yaml:"max_history"`
	MaxActivityLogs int `json:"max_activity_logs" yaml:"max_activity_logs"`
	DailyRepeatAvoidDays int `json:"daily_repeat_avoid_days" yaml:"daily_repeat_avoid_days"`
	// DailyHorizonDays controls how many days ahead the worker will assign
	// daily questions (e.g. 0 = today only, 1 = today+1, ...). If unset or
	// <= 0 the worker will fall back to the DAILY_HORIZON_DAYS environment
	// variable (default 1).
	// Note: DAILY_HORIZON_DAYS is distinct from SERVER_DAILY_HORIZON_DAYS, which
	// overrides this field itself during config loading.
	DailyHorizonDays int `json:"daily_horizon_days" yaml:"daily_horizon_days"`
}
// GetLanguages lists every language configured under language_levels (the keys
// of c.LanguageLevels) in ascending order. The result is never nil: with no
// languages configured it is an empty slice.
func (c *Config) GetLanguages() []string {
	// len of a nil map is 0, so this also covers an unset LanguageLevels.
	names := make([]string, 0, len(c.LanguageLevels))
	for name := range c.LanguageLevels {
		names = append(names, name)
	}
	sort.Strings(names)
	return names
}
// GetLevelsForLanguage returns the configured levels of language, in YAML order.
// An unknown language (or no language_levels at all) yields an empty slice. The
// returned slice is the one stored in the config, not a copy.
func (c *Config) GetLevelsForLanguage(language string) []string {
	// Indexing a nil map is safe and reports found == false.
	if langConfig, found := c.LanguageLevels[language]; found {
		return langConfig.Levels
	}
	return []string{}
}
// GetLevelDescriptionsForLanguage returns the level-name -> description map
// configured for language. An unknown language (or no language_levels at all)
// yields an empty map. The returned map is the one stored in the config, not a copy.
func (c *Config) GetLevelDescriptionsForLanguage(language string) map[string]string {
	langConfig, found := c.LanguageLevels[language]
	if !found {
		return map[string]string{}
	}
	return langConfig.Descriptions
}
// GetAllLevels returns the union of all languages' levels, de-duplicated and
// sorted ascending. The result is never nil.
func (c *Config) GetAllLevels() []string {
	seen := map[string]struct{}{}
	for _, langConfig := range c.LanguageLevels {
		for _, level := range langConfig.Levels {
			seen[level] = struct{}{}
		}
	}
	unique := make([]string, 0, len(seen))
	for level := range seen {
		unique = append(unique, level)
	}
	sort.Strings(unique)
	return unique
}
// GetAllLevelDescriptions merges the level descriptions of every configured
// language into one level-name -> description map. The result is never nil.
//
// When several languages describe the same level, the description from the
// alphabetically first language wins. Languages are visited in sorted order
// rather than map order, so the result is deterministic across calls; ranging
// over the map directly made the winner vary from run to run.
func (c *Config) GetAllLevelDescriptions() map[string]string {
	descriptions := make(map[string]string)
	if len(c.LanguageLevels) == 0 {
		return descriptions
	}
	languages := make([]string, 0, len(c.LanguageLevels))
	for lang := range c.LanguageLevels {
		languages = append(languages, lang)
	}
	sort.Strings(languages)
	for _, lang := range languages {
		for level, description := range c.LanguageLevels[lang].Descriptions {
			if _, exists := descriptions[level]; !exists {
				descriptions[level] = description
			}
		}
	}
	return descriptions
}
// Languages is shorthand for GetLanguages: all configured languages, sorted.
func (c *Config) Languages() []string { return c.GetLanguages() }

// Levels is shorthand for GetAllLevels: every distinct level across all languages.
func (c *Config) Levels() []string { return c.GetAllLevels() }

// LevelDescriptions is shorthand for GetAllLevelDescriptions: level descriptions merged across languages.
func (c *Config) LevelDescriptions() map[string]string { return c.GetAllLevelDescriptions() }
// IsSignupDisabled reports whether auth.signups_disabled is set. Without a
// `system` section in the config, signups count as enabled (false).
func (c *Config) IsSignupDisabled() bool {
	return c.System != nil && c.System.Auth.SignupsDisabled
}
// IsEmailAllowed reports whether email appears in auth.allowed_emails, the
// allow-list that can override disabled signups for OAuth. Both sides are
// trimmed and lower-cased before comparing. Without a `system` section, or
// with an empty list, it returns false.
func (c *Config) IsEmailAllowed(email string) bool {
	if c.System == nil {
		return false
	}
	normalize := func(s string) string { return strings.ToLower(strings.TrimSpace(s)) }
	target := normalize(email)
	for _, candidate := range c.System.Auth.AllowedEmails {
		if normalize(candidate) == target {
			return true
		}
	}
	return false
}
// IsDomainAllowed reports whether domain appears in auth.allowed_domains, the
// domain allow-list that can override disabled signups for OAuth. Both sides
// are trimmed and lower-cased before comparing; no subdomain matching is done.
// Without a `system` section, or with an empty list, it returns false.
func (c *Config) IsDomainAllowed(domain string) bool {
	if c.System == nil {
		return false
	}
	allowed := c.System.Auth.AllowedDomains
	want := strings.ToLower(strings.TrimSpace(domain))
	for i := range allowed {
		if strings.ToLower(strings.TrimSpace(allowed[i])) == want {
			return true
		}
	}
	return false
}
// IsOAuthSignupAllowed reports whether a new account may be created through
// OAuth for the given email address.
//
//   - Without a `system` section in the config, OAuth signup is denied.
//   - When signups are not disabled, every email is accepted.
//   - When signups are disabled, the email (trimmed and lower-cased) must pass
//     contextutils.IsValidEmail and then either be listed in
//     auth.allowed_emails or have its domain listed in auth.allowed_domains.
//
// The domain is the part after the last '@'. An address without '@' or with an
// empty domain is rejected instead of panicking on a missing split part.
//
// NOTE(review): a nil System denies OAuth signup here, whereas IsSignupDisabled
// treats a nil System as "signups enabled". The fail-closed behavior is kept on
// purpose; confirm it is intended.
func (c *Config) IsOAuthSignupAllowed(email string) bool {
	if c.System == nil {
		return false
	}
	// If signups are not disabled, OAuth signup is always allowed.
	if !c.System.Auth.SignupsDisabled {
		return true
	}
	// Signups are disabled: only allow-listed addresses or domains may sign up.
	normalizedEmail := strings.ToLower(strings.TrimSpace(email))
	if !contextutils.IsValidEmail(normalizedEmail) {
		return false
	}
	if c.IsEmailAllowed(normalizedEmail) {
		return true
	}
	at := strings.LastIndex(normalizedEmail, "@")
	if at < 0 || at == len(normalizedEmail)-1 {
		return false
	}
	return c.IsDomainAllowed(normalizedEmail[at+1:])
}
// OpenTelemetryConfig holds all OpenTelemetry-related configuration: the
// `open_telemetry` section of the YAML file (Config.OpenTelemetry), overridable
// via OPEN_TELEMETRY_<UPPER-CASED YAML KEY> environment variables. Headers is a
// map and can only be set from YAML.
//
// NOTE(review): the defaults listed below are not applied anywhere in this
// package (zero values remain after loading); presumably the observability setup
// applies them — confirm.
type OpenTelemetryConfig struct {
	Endpoint string `json:"endpoint" yaml:"endpoint"` // Default: "http://localhost:4317"
	Protocol string `json:"protocol" yaml:"protocol"` // "grpc" or "http", default: "grpc"
	Insecure bool `json:"insecure" yaml:"insecure"` // Default: true (for localhost)
	Headers map[string]string `json:"headers" yaml:"headers"` // For authenticated endpoints
	ServiceName string `json:"service_name" yaml:"service_name"` // Default: "quiz-backend" or "quiz-worker"
	ServiceVersion string `json:"service_version" yaml:"service_version"` // From version package
	EnableTracing bool `json:"enable_tracing" yaml:"enable_tracing"` // Default: true
	EnableMetrics bool `json:"enable_metrics" yaml:"enable_metrics"` // Default: true
	EnableLogging bool `json:"enable_logging" yaml:"enable_logging"` // Default: true (future)
	SamplingRate float64 `json:"sampling_rate" yaml:"sampling_rate"` // Default: 1.0 (100%)
}
// DatabaseConfig represents database configuration: the `database` section of
// the YAML file (Config.Database), overridable via DATABASE_<UPPER-CASED YAML KEY>
// environment variables (e.g. DATABASE_URL, DATABASE_MAX_OPEN_CONNS).
type DatabaseConfig struct {
	URL string `json:"url" yaml:"url"`
	MaxOpenConns int `json:"max_open_conns" yaml:"max_open_conns"` // Maximum number of open connections to the database
	MaxIdleConns int `json:"max_idle_conns" yaml:"max_idle_conns"` // Maximum number of idle connections in the pool
	ConnMaxLifetime time.Duration `json:"conn_max_lifetime" yaml:"conn_max_lifetime"` // Maximum amount of time a connection may be reused
}
// EmailConfig represents email/SMTP configuration: the `email` section of the
// YAML file (Config.Email).
type EmailConfig struct {
	SMTP SMTPConfig `json:"smtp" yaml:"smtp"`
	DailyReminder DailyReminderConfig `json:"daily_reminder" yaml:"daily_reminder"`
	Enabled bool `json:"enabled" yaml:"enabled"`
}
// SMTPConfig represents SMTP server configuration (Config.Email.SMTP), overridable
// via EMAIL_SMTP_<UPPER-CASED YAML KEY> environment variables.
//
// NOTE(review): Password is a secret with a json tag; avoid serializing this struct.
type SMTPConfig struct {
	Host string `json:"host" yaml:"host"`
	Port int `json:"port" yaml:"port"`
	Username string `json:"username" yaml:"username"`
	Password string `json:"password" yaml:"password"`
	FromAddress string `json:"from_address" yaml:"from_address"`
	FromName string `json:"from_name" yaml:"from_name"`
}
// DailyReminderConfig represents daily reminder email configuration (Config.Email.DailyReminder).
type DailyReminderConfig struct {
	Enabled bool `json:"enabled" yaml:"enabled"`
	Hour int `json:"hour" yaml:"hour"` // Hour of day to send (0-23); not validated in this package
}
// NewConfig builds the application configuration in two steps: it reads the YAML
// file selected by loadConfigWithOverrides (QUIZ_CONFIG_FILE or ./config.yaml) and
// then lets non-empty environment variables replace individual values.
// A load failure is wrapped as contextutils.ErrInternalError.
func NewConfig() (result0 *Config, err error) {
	cfg, loadErr := loadConfigWithOverrides()
	if loadErr != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to load config: %w", loadErr)
	}
	// Environment variables take precedence over the file.
	cfg.overrideFromEnv()
	return cfg, nil
}
// overrideFromEnv layers environment variables over c in place, starting at the
// root of the config (no variable-name prefix). A set, non-empty value — which
// for non-string fields must also parse — replaces the value read from YAML.
func (c *Config) overrideFromEnv() {
	var root interface{} = c
	overrideStructFromEnv(root)
}
// overrideStructFromEnv applies environment overrides to the struct that v points
// to, treating it as the root of the naming scheme (no variable-name prefix).
// See overrideStructFromEnvWithPrefix for how names are formed and values parsed.
func overrideStructFromEnv(v interface{}) {
	const rootPrefix = ""
	overrideStructFromEnvWithPrefix(v, rootPrefix)
}
// overrideStructFromEnvWithPrefix walks the fields of the struct that v points to
// and overwrites each one from an environment variable when that variable is set
// to a non-empty value.
//
// Variable names are derived from the field's yaml tag. The key part of the tag
// (everything before the first comma, so options such as ",omitempty" are
// ignored) is upper-cased with '-' replaced by '_'. It is joined to prefix with
// '_' when prefix is non-empty: `yaml:"system,omitempty"` maps to SYSTEM, not
// "SYSTEM,OMITEMPTY". Fields that are unexported, untagged, tagged "-", or whose
// tag has an empty key are skipped.
//
// Nested structs and non-nil pointers to structs are processed recursively, with
// their own variable name as the new prefix (Config.Database.MaxOpenConns maps to
// DATABASE_MAX_OPEN_CONNS). Nil pointers are not allocated, so a section absent
// from the YAML cannot be created from the environment.
//
// Leaf values are parsed by setFieldFromEnvString. Values that do not parse are
// ignored and the existing value is kept: this is a best-effort overlay on top of
// the YAML configuration.
//
// v should be a pointer; a struct passed by value is not settable and is left untouched.
func overrideStructFromEnvWithPrefix(v interface{}, prefix string) {
	val := reflect.ValueOf(v)
	if val.Kind() == reflect.Ptr {
		val = val.Elem()
	}
	if val.Kind() != reflect.Struct {
		return
	}
	typ := val.Type()
	for i := 0; i < val.NumField(); i++ {
		field := val.Field(i)
		// CanSet rejects unexported fields and values that are not addressable.
		if !field.CanSet() {
			continue
		}
		key := typ.Field(i).Tag.Get("yaml")
		if comma := strings.IndexByte(key, ','); comma >= 0 {
			key = key[:comma]
		}
		if key == "" || key == "-" {
			continue
		}
		envKey := strings.ToUpper(strings.ReplaceAll(key, "-", "_"))
		if prefix != "" {
			envKey = prefix + "_" + envKey
		}
		switch field.Kind() {
		case reflect.Struct:
			// CanSet above guarantees the field is addressable.
			overrideStructFromEnvWithPrefix(field.Addr().Interface(), envKey)
		case reflect.Ptr:
			if !field.IsNil() && field.Elem().Kind() == reflect.Struct {
				overrideStructFromEnvWithPrefix(field.Interface(), envKey)
			}
		default:
			if envVal := os.Getenv(envKey); envVal != "" {
				setFieldFromEnvString(field, envVal)
			}
		}
	}
}

// setFieldFromEnvString parses raw according to field's type and stores it in
// field, which must be settable. It supports:
//   - strings (stored verbatim);
//   - signed and unsigned integers, range-checked against the field's bit size
//     so out-of-range values are ignored instead of wrapping;
//   - time.Duration, in Go duration syntax ("30s", "5m"), falling back to a raw
//     nanosecond count for compatibility;
//   - floats and bools;
//   - slices whose element kind is string, from a comma-separated list whose
//     items are trimmed and whose empty items are dropped.
//
// Values that fail to parse, and unsupported kinds such as maps, are ignored.
func setFieldFromEnvString(field reflect.Value, raw string) {
	switch field.Kind() {
	case reflect.String:
		field.SetString(raw)
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		if field.Type() == reflect.TypeOf(time.Duration(0)) {
			if d, err := time.ParseDuration(raw); err == nil {
				field.SetInt(int64(d))
				return
			}
		}
		if n, err := strconv.ParseInt(raw, 10, field.Type().Bits()); err == nil {
			field.SetInt(n)
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		if n, err := strconv.ParseUint(raw, 10, field.Type().Bits()); err == nil {
			field.SetUint(n)
		}
	case reflect.Float32, reflect.Float64:
		if f, err := strconv.ParseFloat(raw, field.Type().Bits()); err == nil {
			field.SetFloat(f)
		}
	case reflect.Bool:
		if b, err := strconv.ParseBool(raw); err == nil {
			field.SetBool(b)
		}
	case reflect.Slice:
		elemType := field.Type().Elem()
		if elemType.Kind() != reflect.String {
			return
		}
		parts := strings.Split(raw, ",")
		// Build a value of the field's own type so named slice types (e.g. type X []string) work too.
		items := reflect.MakeSlice(field.Type(), 0, len(parts))
		for _, part := range parts {
			if part = strings.TrimSpace(part); part != "" {
				items = reflect.Append(items, reflect.ValueOf(part).Convert(elemType))
			}
		}
		field.Set(items)
	}
}
// loadConfigWithOverrides decides which configuration file to read and parses it.
//
// When the QUIZ_CONFIG_FILE environment variable is non-empty, that path is
// loaded and any failure is wrapped together with the path. Otherwise
// "config.yaml" (relative to the working directory) is loaded and its error,
// if any, is returned unchanged.
func loadConfigWithOverrides() (result0 *Config, err error) {
	envPath := os.Getenv("QUIZ_CONFIG_FILE")
	if envPath == "" {
		// No explicit file requested: use the conventional default name.
		return loadConfigFromFile("config.yaml")
	}

	cfg, loadErr := loadConfigFromFile(envPath)
	if loadErr != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to load config from %s: %w", envPath, loadErr)
	}
	return cfg, nil
}
// loadConfigFromFile reads the YAML document at path and decodes it into a new
// Config. Read and decode errors are returned as-is, without extra wrapping.
func loadConfigFromFile(path string) (result0 *Config, err error) {
	raw, err := os.ReadFile(path)
	if err != nil {
		return nil, err
	}

	cfg := new(Config)
	if err = yaml.Unmarshal(raw, cfg); err != nil {
		return nil, err
	}
	return cfg, nil
}
// Package database provides database connection and migration functionality.
package database
import (
"context"
"database/sql"
"errors"
"fmt"
"net/url"
"os"
"path/filepath"
"strings"
"sync"
"quizapp/internal/config"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
// Import PostgreSQL driver for database/sql
_ "github.com/lib/pq"
// Add golang-migrate imports
"github.com/golang-migrate/migrate/v4"
_ "github.com/golang-migrate/migrate/v4/database/postgres" // required for golang-migrate postgres driver
_ "github.com/golang-migrate/migrate/v4/source/file" // required for golang-migrate file source
// OpenTelemetry SQL instrumentation
"go.nhat.io/otelsql"
"go.opentelemetry.io/otel/attribute"
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
)
// Manager groups the database lifecycle operations (opening the connection
// pool, applying schema.sql and running golang-migrate migrations) and reports
// progress through its logger.
type Manager struct {
	logger *observability.Logger // destination for connection and migration logs
}

// Process-wide cache for the otelsql-instrumented driver: otelsql.Register is
// invoked at most once (guarded by otelDriverOnce) and the resulting driver
// name and registration error are reused by every later connection attempt.
var (
	otelDriverOnce      sync.Once
	otelDriverNameCache string
	otelDriverErr       error
)

// NewManager builds a Manager that writes its logs to logger.
func NewManager(logger *observability.Logger) *Manager {
	m := &Manager{}
	m.logger = logger
	return m
}

// ErrTableAlreadyExists is the sentinel for "the table being created is already
// present"; isTableExistsError recognises it via errors.Is.
var ErrTableAlreadyExists = errors.New("table already exists")
// DefaultDatabaseConfig returns the baseline pool settings: at most 25 open and
// 5 idle connections, with config.DatabaseConnMaxLifetime as the connection
// lifetime. URL is left empty unless the TEST_DATABASE_URL environment
// variable is set to a non-empty value, in which case that value is used.
func DefaultDatabaseConfig() config.DatabaseConfig {
	cfg := config.DatabaseConfig{
		MaxOpenConns:    25,
		MaxIdleConns:    5,
		ConnMaxLifetime: config.DatabaseConnMaxLifetime,
	}

	// Tests point the application at their own database through this variable.
	if testURL, ok := os.LookupEnv("TEST_DATABASE_URL"); ok && testURL != "" {
		cfg.URL = testURL
	}
	return cfg
}
// InitDB opens a connection pool to databaseURL using the pool settings from
// DefaultDatabaseConfig, then applies the application schema and migrations
// (see InitDBWithConfig).
//
// The URL is recorded on the trace span with its password redacted. Strings
// that are not scheme-qualified URLs (e.g. key=value DSNs, which url.Parse
// cannot redact) are recorded as "[redacted]", so credentials never reach the
// tracing backend.
func (dm *Manager) InitDB(databaseURL string) (result0 *sql.DB, err error) {
	safeURL := "[redacted]"
	if u, parseErr := url.Parse(databaseURL); parseErr == nil && u.Scheme != "" {
		safeURL = u.Redacted()
	}

	_, span := observability.TraceDatabaseFunction(context.Background(), "InitDB",
		attribute.String("db.url", safeURL),
		attribute.String("db.name", extractDatabaseName(databaseURL)),
		attribute.String("db.system", "postgresql"),
		attribute.Bool("migrations.enabled", true),
	)
	defer observability.FinishSpan(span, &err)

	cfg := DefaultDatabaseConfig()
	cfg.URL = databaseURL
	return dm.InitDBWithConfig(cfg)
}
// InitDBWithConfig opens a connection pool described by config (via
// InitDBWithoutMigrations) and then applies the schema and migrations via
// RunMigrations.
//
// If the migrations fail, the freshly opened pool is closed before the error
// is returned. Callers receive a nil *sql.DB in that case and could never
// close the pool themselves.
func (dm *Manager) InitDBWithConfig(config config.DatabaseConfig) (result0 *sql.DB, err error) {
	// Never put the password on the span; non-URL DSNs are hidden entirely.
	safeURL := "[redacted]"
	if u, parseErr := url.Parse(config.URL); parseErr == nil && u.Scheme != "" {
		safeURL = u.Redacted()
	}

	ctx, span := observability.TraceDatabaseFunction(context.Background(), "InitDBWithConfig",
		attribute.String("db.url", safeURL),
		attribute.String("db.name", extractDatabaseName(config.URL)),
		attribute.String("db.system", "postgresql"),
		attribute.Bool("migrations.enabled", true),
		attribute.Int("db.max_open_conns", config.MaxOpenConns),
		attribute.Int("db.max_idle_conns", config.MaxIdleConns),
		attribute.String("db.conn_max_lifetime", config.ConnMaxLifetime.String()),
	)
	defer observability.FinishSpan(span, &err)

	db, err := dm.InitDBWithoutMigrations(config)
	if err != nil {
		return nil, err
	}

	if err := dm.RunMigrations(db); err != nil {
		// Release the pool; the caller only gets the error.
		if closeErr := db.Close(); closeErr != nil {
			dm.logger.Error(ctx, "Failed to close database connection after migration failure", closeErr)
		}
		return nil, err
	}
	return db, nil
}
// extractDatabaseName returns the database name from a PostgreSQL connection
// string. It falls back to "quiz_db" when no name can be determined.
//
// Two formats are understood:
//   - URLs such as postgres://user:pass@host:5432/dbname?sslmode=disable, where
//     the name is the path component;
//   - key=value DSNs such as "host=localhost dbname=app sslmode=disable",
//     where the name is the value of the dbname key (single quotes trimmed).
//
// The result is used as a span attribute and as the otelsql database name, so
// it must never echo other parts of the connection string (host, user,
// password). Anything that is not clearly a database name yields the default.
func extractDatabaseName(databaseURL string) string {
	const defaultName = "quiz_db"

	if u, err := url.Parse(databaseURL); err == nil && u.Scheme != "" {
		// URL form: the database is the path; an empty path means "not specified".
		if name := strings.Trim(u.Path, "/"); name != "" {
			return name
		}
		return defaultName
	}

	// key=value DSN form: only an explicit dbname entry is trusted.
	for _, field := range strings.Fields(databaseURL) {
		key, value, found := strings.Cut(field, "=")
		if !found || key != "dbname" {
			continue
		}
		if name := strings.Trim(value, "'"); name != "" {
			return name
		}
	}
	return defaultName
}
// InitDBWithoutMigrations opens a connection pool for config.URL through the
// otelsql-instrumented postgres driver. It then applies the pool limits from
// config and verifies connectivity with a ping. It does not touch the schema.
//
// The instrumented driver is registered once per process (otelDriverOnce). Its
// database-name option comes from the URL passed on the first call and is
// reused for every later call. If the ping fails, the pool is closed before the
// error is returned.
func (dm *Manager) InitDBWithoutMigrations(config config.DatabaseConfig) (result0 *sql.DB, err error) {
	// Record the URL with its password redacted; non-URL DSNs are hidden entirely.
	safeURL := "[redacted]"
	if u, parseErr := url.Parse(config.URL); parseErr == nil && u.Scheme != "" {
		safeURL = u.Redacted()
	}
	ctx, span := observability.TraceDatabaseFunction(context.Background(), "InitDBWithoutMigrations",
		attribute.String("database.url", safeURL),
	)
	defer observability.FinishSpan(span, &err)

	// Register the OpenTelemetry SQL driver once per process and reuse its name.
	otelDriverOnce.Do(func() {
		otelDriverNameCache, otelDriverErr = otelsql.Register("postgres",
			otelsql.WithDatabaseName(extractDatabaseName(config.URL)),
			otelsql.TraceQueryWithArgs(),
			otelsql.WithSystem(semconv.DBSystemPostgreSQL),
			otelsql.TraceRowsAffected(),
		)
	})
	if otelDriverErr != nil {
		return nil, contextutils.WrapError(otelDriverErr, "failed to register otelsql driver")
	}

	db, err := sql.Open(otelDriverNameCache, config.URL)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to open database connection")
	}

	db.SetMaxOpenConns(config.MaxOpenConns)
	db.SetMaxIdleConns(config.MaxIdleConns)
	db.SetConnMaxLifetime(config.ConnMaxLifetime)

	// sql.Open does not necessarily connect; the ping verifies the server is
	// reachable. It uses this span's context rather than a detached one.
	if err := db.PingContext(ctx); err != nil {
		if closeErr := db.Close(); closeErr != nil {
			dm.logger.Error(ctx, "Failed to close database connection after ping failure", closeErr)
		}
		return nil, contextutils.WrapError(err, "failed to ping database")
	}

	dm.logger.Info(ctx, "Database connection established without migrations", map[string]interface{}{
		"max_open_conns":    config.MaxOpenConns,
		"max_idle_conns":    config.MaxIdleConns,
		"conn_max_lifetime": config.ConnMaxLifetime,
	})
	return db, nil
}
// RunMigrations applies schema.sql to db (runApplicationSchema) and then runs
// golang-migrate over the migrations directory (runGolangMigrate). It stops at
// the first failure. All log lines use the span's context so they can be
// correlated with this trace.
//
// NOTE(review): runGolangMigrate does not use db. It connects with the
// DATABASE_URL / TEST_DATABASE_URL environment variables, so both steps only
// target the same database when those variables match the URL db was opened
// with.
func (dm *Manager) RunMigrations(db *sql.DB) (err error) {
	ctx, span := observability.TraceDatabaseFunction(context.Background(), "RunMigrations",
		attribute.String("db.system", "postgresql"),
		attribute.String("migration.type", "application_schema"),
	)
	defer observability.FinishSpan(span, &err)

	dm.logger.Info(ctx, "Starting database migrations...")

	// The base schema goes first so migrations can build on it.
	if err := dm.runApplicationSchema(db); err != nil {
		return contextutils.WrapError(err, "failed to run application schema")
	}
	dm.logger.Info(ctx, "Application schema applied successfully")

	// runGolangMigrate fails hard when the migrations directory cannot be found.
	if err := dm.runGolangMigrate(); err != nil {
		return contextutils.WrapError(err, "failed to run golang-migrate migrations")
	}
	dm.logger.Info(ctx, "Database migrations completed successfully")
	return nil
}
// runGolangMigrate applies pending golang-migrate migrations from the
// migrations directory located by GetMigrationsPath.
//
// Outcomes:
//   - Fails when the directory cannot be found or read.
//   - Returns nil without connecting when the directory holds no *.up.sql files.
//   - Otherwise connects using DATABASE_URL, falling back to TEST_DATABASE_URL,
//     and runs all pending up-migrations. migrate.ErrNoChange counts as success.
//
// The connection URL is logged with its password redacted; non-URL DSNs are
// logged as "[redacted]".
func (dm *Manager) runGolangMigrate() (err error) {
	// Start the span first so a failure to locate the directory is recorded too.
	ctx, span := observability.TraceDatabaseFunction(context.Background(), "runGolangMigrate",
		attribute.String("db.system", "postgresql"),
		attribute.String("migration.type", "golang_migrate"),
	)
	defer observability.FinishSpan(span, &err)

	migrationsPath, err := dm.GetMigrationsPath()
	if err != nil {
		dm.logger.Error(ctx, "Could not find migrations path", err)
		return err // HARD FAIL if migrations path is not set
	}
	span.SetAttributes(attribute.String("migration.path", migrationsPath))

	if migrationsPath == "" {
		err = errors.New("no golang-migrate migrations directory found")
		dm.logger.Error(ctx, "No golang-migrate migrations directory found, hard fail!", err)
		return err // HARD FAIL
	}

	if _, statErr := os.Stat(migrationsPath); os.IsNotExist(statErr) {
		dm.logger.Error(ctx, "Migrations directory does not exist", statErr)
		return statErr // HARD FAIL if directory does not exist
	}

	files, err := os.ReadDir(migrationsPath)
	if err != nil {
		dm.logger.Error(ctx, "Could not read migrations directory", err)
		return err // HARD FAIL
	}

	// Only *.up.sql files count as migrations.
	migrationFileCount := 0
	for _, file := range files {
		if !file.IsDir() && strings.HasSuffix(file.Name(), ".up.sql") {
			migrationFileCount++
		}
	}
	span.SetAttributes(attribute.Int("migration.files.count", migrationFileCount))
	if migrationFileCount == 0 {
		dm.logger.Info(ctx, fmt.Sprintf("No migration files found in %s. Skipping golang-migrate.", migrationsPath))
		return nil
	}

	dbURL := os.Getenv("DATABASE_URL")
	if dbURL == "" {
		dbURL = os.Getenv("TEST_DATABASE_URL")
	}
	if dbURL == "" {
		return errors.New("database_url or test_database_url must be set for migrations")
	}

	// golang-migrate's file source takes a file:// URL with forward slashes.
	migrationSourceURL := "file://" + filepath.ToSlash(migrationsPath)

	// Log where we migrate from and to, but never the database password.
	loggedDBURL := "[redacted]"
	if u, parseErr := url.Parse(dbURL); parseErr == nil && u.Scheme != "" {
		loggedDBURL = u.Redacted()
	}
	dm.logger.Info(ctx, "Migration paths", map[string]interface{}{
		"migrations_path": migrationsPath,
		"source_url":      migrationSourceURL,
		"db_url":          loggedDBURL,
	})

	m, err := migrate.New(migrationSourceURL, dbURL)
	if err != nil {
		return contextutils.WrapError(err, "failed to initialize golang-migrate")
	}
	defer func() {
		// Close reports the source and database close errors separately.
		srcErr, dbErr := m.Close()
		if srcErr != nil {
			dm.logger.Error(ctx, "Error closing migration source", srcErr)
		}
		if dbErr != nil {
			dm.logger.Error(ctx, "Error closing migration", dbErr)
		}
	}()

	upErr := m.Up()
	switch {
	case errors.Is(upErr, migrate.ErrNoChange):
		dm.logger.Info(ctx, "No new golang-migrate migrations to apply.")
	case upErr != nil:
		return contextutils.WrapError(upErr, "golang-migrate up failed")
	default:
		dm.logger.Info(ctx, "golang-migrate migrations applied successfully.")
	}
	return nil
}
// runApplicationSchema executes schema.sql, as located by getSchemaPath,
// against db.
//
// The file is split into statements by parseSchemaStatements and run in two
// passes:
//  1. Every statement that does not start with "CREATE INDEX" runs in file order.
//  2. The deferred CREATE INDEX statements run afterwards, so their tables exist.
//
// Errors recognised by isTableExistsError are ignored, so the schema can be
// re-applied to an existing database. Statements run with the span's context.
func (dm *Manager) runApplicationSchema(db *sql.DB) (err error) {
	schemaPath, err := dm.getSchemaPath()
	if err != nil {
		return contextutils.WrapError(err, "failed to find schema file")
	}

	ctx, span := observability.TraceDatabaseFunction(context.Background(), "runApplicationSchema",
		attribute.String("db.system", "postgresql"),
		attribute.String("migration.type", "application_schema"),
		attribute.String("schema.path", schemaPath),
	)
	defer observability.FinishSpan(span, &err)

	schemaSQL, err := os.ReadFile(schemaPath)
	if err != nil {
		return contextutils.WrapError(err, "failed to read schema file")
	}
	span.SetAttributes(attribute.Int("schema.file.size", len(schemaSQL)))

	statements := dm.parseSchemaStatements(string(schemaSQL))
	span.SetAttributes(attribute.Int("schema.statements.count", len(statements)))

	// Pass 1: everything except plain CREATE INDEX statements.
	var indexStatements []string
	for _, statement := range statements {
		statement = strings.TrimSpace(statement)
		if statement == "" {
			continue
		}
		if strings.HasPrefix(strings.ToUpper(statement), "CREATE INDEX") {
			indexStatements = append(indexStatements, statement)
			continue
		}
		// For backwards compatibility, "already exists" errors are ignored.
		if _, execErr := db.ExecContext(ctx, statement); execErr != nil && !dm.isTableExistsError(execErr) {
			return contextutils.WrapErrorf(execErr, "failed to execute schema statement: %s", statement)
		}
	}
	span.SetAttributes(attribute.Int("schema.index_statements.count", len(indexStatements)))

	// Pass 2: the deferred index creation statements.
	for _, statement := range indexStatements {
		if _, execErr := db.ExecContext(ctx, statement); execErr != nil && !dm.isTableExistsError(execErr) {
			return contextutils.WrapErrorf(execErr, "failed to execute index statement: %s", statement)
		}
	}
	return nil
}
// getSchemaPath locates schema.sql. It checks the working directory and then
// each parent directory in turn, and returns the first match that is not a
// directory. It fails once the filesystem root has been checked without a
// match, or if the working directory cannot be determined.
func (dm *Manager) getSchemaPath() (result0 string, err error) {
	_, span := observability.TraceDatabaseFunction(context.Background(), "getSchemaPath",
		attribute.String("file.name", "schema.sql"),
	)
	defer observability.FinishSpan(span, &err)

	currentDir, err := os.Getwd()
	if err != nil {
		return "", err
	}
	span.SetAttributes(attribute.String("search.start_dir", currentDir))

	for {
		schemaPath := filepath.Join(currentDir, "schema.sql")
		// A directory that happens to be named schema.sql cannot be read as SQL; keep looking.
		if info, statErr := os.Stat(schemaPath); statErr == nil && !info.IsDir() {
			span.SetAttributes(attribute.String("schema.found_path", schemaPath))
			return schemaPath, nil
		}

		parentDir := filepath.Dir(currentDir)
		if parentDir == currentDir {
			// filepath.Dir is a fixed point only at the filesystem root.
			span.SetAttributes(attribute.String("search.result", "not_found"))
			err = contextutils.ErrorWithContextf("schema.sql not found in any parent directory")
			return "", err
		}
		currentDir = parentDir
	}
}
// parseSchemaStatements splits the contents of a schema file into individual
// SQL statements. Comments are removed, surrounding whitespace is trimmed,
// runs of whitespace outside quoted text collapse to a single space, and empty
// statements are dropped.
//
// The split follows SQL lexical structure rather than lines:
//   - Semicolons and comment markers inside single-quoted strings,
//     double-quoted identifiers and dollar-quoted bodies ($$...$$ or
//     $tag$...$tag$, as used by PL/pgSQL functions) are kept verbatim.
//   - A /* ... */ comment that opens and closes on the same line does not hide
//     the statements after it.
//
// NOTE(review): backslash escapes inside E'...' strings are not recognised;
// only the standard doubled-quote escape is.
func (dm *Manager) parseSchemaStatements(schemaSQL string) []string {
	_, span := observability.TraceDatabaseFunction(context.Background(), "parseSchemaStatements",
		attribute.Int("input.length", len(schemaSQL)),
	)
	defer span.End()

	result := splitSchemaSQL(schemaSQL)
	span.SetAttributes(attribute.Int("statements.parsed", len(result)))
	return result
}

// splitSchemaSQL is the tracing-free scanner behind parseSchemaStatements.
func splitSchemaSQL(src string) []string {
	var (
		statements []string
		current    strings.Builder
		needSpace  bool // whitespace or a comment was skipped since the last emitted token
	)
	// emit appends token to the current statement, inserting one separating
	// space if whitespace/comments were skipped before it.
	emit := func(token string) {
		if needSpace && current.Len() > 0 {
			current.WriteByte(' ')
		}
		needSpace = false
		current.WriteString(token)
	}
	// flush ends the current statement and keeps it if it is non-empty.
	flush := func() {
		if stmt := strings.TrimSpace(current.String()); stmt != "" {
			statements = append(statements, stmt)
		}
		current.Reset()
		needSpace = false
	}

	for i := 0; i < len(src); {
		rest := src[i:]
		c := src[i]

		// A '$' directly after an identifier character belongs to the identifier.
		tag := ""
		if c == '$' && (i == 0 || !isSchemaIdentByte(src[i-1])) {
			tag = schemaDollarTag(rest)
		}

		switch {
		case strings.HasPrefix(rest, "--"):
			// Line comment: skip up to (not past) the newline.
			if end := strings.IndexByte(rest, '\n'); end >= 0 {
				i += end
			} else {
				i = len(src)
			}
			needSpace = true
		case strings.HasPrefix(rest, "/*"):
			// Block comment: skip through the closing "*/" (or to EOF if unterminated).
			if end := strings.Index(rest[2:], "*/"); end >= 0 {
				i += 2 + end + 2
			} else {
				i = len(src)
			}
			needSpace = true
		case c == ';':
			flush()
			i++
		case c == ' ' || c == '\t' || c == '\n' || c == '\r' || c == '\f' || c == '\v':
			needSpace = true
			i++
		case c == '\'' || c == '"':
			n := schemaQuotedLen(rest, c)
			emit(rest[:n])
			i += n
		case tag != "":
			// Dollar-quoted body: copy verbatim up to and including the matching tag.
			n := len(rest)
			if end := strings.Index(rest[len(tag):], tag); end >= 0 {
				n = len(tag) + end + len(tag)
			}
			emit(rest[:n])
			i += n
		default:
			emit(rest[:1])
			i++
		}
	}
	flush()
	return statements
}

// schemaQuotedLen returns the length of the quoted token at the start of s
// (s[0] == quote). A doubled quote counts as an escaped quote character, and
// an unterminated token extends to the end of s.
func schemaQuotedLen(s string, quote byte) int {
	for j := 1; j < len(s); j++ {
		if s[j] != quote {
			continue
		}
		if j+1 < len(s) && s[j+1] == quote {
			j++ // doubled quote: escaped, keep scanning
			continue
		}
		return j + 1
	}
	return len(s)
}

// schemaDollarTag returns the dollar-quote opening delimiter at the start of s
// ("$$" or "$tag$"). It returns "" when s does not start with one, e.g. "$1".
func schemaDollarTag(s string) string {
	for j := 1; j < len(s); j++ {
		c := s[j]
		switch {
		case c == '$':
			return s[:j+1]
		case c == '_' || ('a' <= c && c <= 'z') || ('A' <= c && c <= 'Z'):
			// valid tag character
		case '0' <= c && c <= '9' && j > 1:
			// digits are allowed after the first tag character
		default:
			return ""
		}
	}
	return ""
}

// isSchemaIdentByte reports whether b can appear inside an unquoted SQL
// identifier (ASCII letters, digits, '_', '$', and bytes of multi-byte UTF-8
// characters).
func isSchemaIdentByte(b byte) bool {
	return b == '_' || b == '$' || b >= 0x80 ||
		('a' <= b && b <= 'z') || ('A' <= b && b <= 'Z') || ('0' <= b && b <= '9')
}
// isTableExistsError reports whether err signals that the object being created
// already exists. It matches the ErrTableAlreadyExists sentinel via errors.Is
// and, for backwards compatibility, any error whose message contains
// "already exists". The string match applies to tables, indexes and any other
// object. A nil error is never such an error.
func (dm *Manager) isTableExistsError(err error) bool {
	if err == nil {
		// Guard: err.Error() below would panic on a nil error.
		return false
	}

	_, span := observability.TraceDatabaseFunction(context.Background(), "isTableExistsError")
	defer span.End()

	if errors.Is(err, ErrTableAlreadyExists) {
		return true
	}
	return strings.Contains(err.Error(), "already exists")
}
// GetMigrationsPath locates the "migrations" directory. It checks the working
// directory and then each parent directory in turn, and returns the first
// match that is a directory. It fails once the filesystem root has been
// checked without a match, or if the working directory cannot be determined.
func (dm *Manager) GetMigrationsPath() (result0 string, err error) {
	_, span := observability.TraceDatabaseFunction(context.Background(), "GetMigrationsPath",
		attribute.String("migration.dir.name", "migrations"),
	)
	defer observability.FinishSpan(span, &err)

	currentDir, err := os.Getwd()
	if err != nil {
		return "", err
	}
	span.SetAttributes(attribute.String("search.start_dir", currentDir))

	for {
		migrationsPath := filepath.Join(currentDir, "migrations")
		// A regular file named "migrations" (e.g. a built binary) is not the directory we want.
		if info, statErr := os.Stat(migrationsPath); statErr == nil && info.IsDir() {
			span.SetAttributes(attribute.String("migration.found_path", migrationsPath))
			return migrationsPath, nil
		}

		parentDir := filepath.Dir(currentDir)
		if parentDir == currentDir {
			// filepath.Dir is a fixed point only at the filesystem root.
			span.SetAttributes(attribute.String("search.result", "not_found"))
			err = contextutils.ErrorWithContextf("migrations directory not found in any parent directory")
			return "", err
		}
		currentDir = parentDir
	}
}
// Package di provides dependency injection container for managing service lifecycle and dependencies.
package di
import (
"context"
"database/sql"
"sync"
"quizapp/internal/config"
"quizapp/internal/database"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
)
// ServiceContainerInterface is the contract implemented by ServiceContainer:
// service lookup (generic and typed), access to shared infrastructure, and
// lifecycle control.
type ServiceContainerInterface interface {
	// Generic lookup by registration name.
	GetService(name string) (interface{}, error)

	// Typed accessors for the registered application services.
	GetUserService() (services.UserServiceInterface, error)
	GetQuestionService() (services.QuestionServiceInterface, error)
	GetLearningService() (services.LearningServiceInterface, error)
	GetAIService() (services.AIServiceInterface, error)
	GetWorkerService() (services.WorkerServiceInterface, error)
	GetDailyQuestionService() (services.DailyQuestionServiceInterface, error)
	GetOAuthService() (*services.OAuthService, error)
	GetGenerationHintService() (services.GenerationHintServiceInterface, error)
	GetEmailService() (services.EmailServiceInterface, error)

	// Shared infrastructure.
	GetDatabase() *sql.DB
	GetConfig() *config.Config
	GetLogger() *observability.Logger

	// Lifecycle.
	Initialize(ctx context.Context) error
	Shutdown(ctx context.Context) error
	EnsureAdminUser(ctx context.Context) error
}
// ServiceContainer owns the database connection and every application service,
// and controls their startup and shutdown.
type ServiceContainer struct {
	// Inputs supplied by NewServiceContainer.
	cfg    *config.Config
	logger *observability.Logger

	// Database state, populated by Initialize.
	dbManager *database.Manager
	db        *sql.DB

	// Registered services keyed by name ("user", "question", ...).
	services map[string]interface{}
	// Write-locked by Initialize and Shutdown, read-locked by GetService.
	mu sync.RWMutex
	// Teardown hooks, executed by cleanup in reverse registration order.
	shutdownFuncs []func(context.Context) error
}
// NewServiceContainer returns an empty container bound to cfg and logger.
// Nothing is connected or constructed until Initialize is called.
func NewServiceContainer(cfg *config.Config, logger *observability.Logger) *ServiceContainer {
	sc := &ServiceContainer{services: map[string]interface{}{}}
	sc.cfg, sc.logger = cfg, logger
	return sc
}
// Initialize opens the database (running schema and migrations), constructs
// every service and then starts those that implement Startup. It holds the
// container's write lock throughout.
//
// If startup fails, cleanup is run as a best-effort rollback (its own error is
// discarded) and the startup error is returned, wrapped.
func (sc *ServiceContainer) Initialize(ctx context.Context) error {
	sc.mu.Lock()
	defer sc.mu.Unlock()

	sc.dbManager = database.NewManager(sc.logger)
	db, err := sc.dbManager.InitDBWithConfig(sc.cfg.Database)
	if err != nil {
		return contextutils.WrapErrorf(err, "failed to initialize database")
	}
	sc.db = db
	closeDB := func(_ context.Context) error { return db.Close() }
	sc.shutdownFuncs = append(sc.shutdownFuncs, closeDB)

	sc.initializeServices(ctx)

	err = sc.startupServices(ctx)
	if err == nil {
		return nil
	}
	// Best-effort rollback of whatever was set up so far.
	_ = sc.cleanup(ctx)
	return contextutils.WrapErrorf(err, "failed to startup services")
}
// GetService returns the service registered under name, read under the
// container's read lock. It returns an error when no such service exists.
func (sc *ServiceContainer) GetService(name string) (interface{}, error) {
	sc.mu.RLock()
	service, exists := sc.services[name]
	sc.mu.RUnlock()

	if exists {
		return service, nil
	}
	return nil, contextutils.ErrorWithContextf("service %s not found", name)
}
// GetServiceAs looks up the service registered under name and asserts it to T.
//
// It returns T's zero value and an error when the service is missing or when
// the registered value is not a T. For an interface T, "not a T" means it does
// not implement T.
func GetServiceAs[T any](sc *ServiceContainer, name string) (T, error) {
	var zero T
	service, err := sc.GetService(name)
	if err != nil {
		return zero, err
	}
	typed, ok := service.(T)
	if !ok {
		// Formatting zero with %T prints "<nil>" whenever T is an interface
		// type, so report the registered value's dynamic type instead.
		return zero, contextutils.ErrorWithContextf("service %s is not of expected type: got %T", name, service)
	}
	return typed, nil
}
// GetUserService returns the service registered as "user".
func (sc *ServiceContainer) GetUserService() (services.UserServiceInterface, error) {
	svc, err := GetServiceAs[services.UserServiceInterface](sc, "user")
	return svc, err
}

// GetQuestionService returns the service registered as "question".
func (sc *ServiceContainer) GetQuestionService() (services.QuestionServiceInterface, error) {
	svc, err := GetServiceAs[services.QuestionServiceInterface](sc, "question")
	return svc, err
}

// GetLearningService returns the service registered as "learning".
func (sc *ServiceContainer) GetLearningService() (services.LearningServiceInterface, error) {
	svc, err := GetServiceAs[services.LearningServiceInterface](sc, "learning")
	return svc, err
}

// GetAIService returns the service registered as "ai".
func (sc *ServiceContainer) GetAIService() (services.AIServiceInterface, error) {
	svc, err := GetServiceAs[services.AIServiceInterface](sc, "ai")
	return svc, err
}

// GetWorkerService returns the service registered as "worker".
func (sc *ServiceContainer) GetWorkerService() (services.WorkerServiceInterface, error) {
	svc, err := GetServiceAs[services.WorkerServiceInterface](sc, "worker")
	return svc, err
}

// GetDailyQuestionService returns the service registered as "daily_question".
func (sc *ServiceContainer) GetDailyQuestionService() (services.DailyQuestionServiceInterface, error) {
	svc, err := GetServiceAs[services.DailyQuestionServiceInterface](sc, "daily_question")
	return svc, err
}

// GetOAuthService returns the concrete *services.OAuthService registered as
// "oauth". It errors when the lookup fails or the value has another type.
func (sc *ServiceContainer) GetOAuthService() (*services.OAuthService, error) {
	raw, err := sc.GetService("oauth")
	if err != nil {
		return nil, err
	}
	if oauth, ok := raw.(*services.OAuthService); ok {
		return oauth, nil
	}
	return nil, contextutils.ErrorWithContextf("oauth service has incorrect type")
}

// GetGenerationHintService returns the service registered as "generation_hint".
func (sc *ServiceContainer) GetGenerationHintService() (services.GenerationHintServiceInterface, error) {
	svc, err := GetServiceAs[services.GenerationHintServiceInterface](sc, "generation_hint")
	return svc, err
}

// GetEmailService returns the service registered as "email".
func (sc *ServiceContainer) GetEmailService() (services.EmailServiceInterface, error) {
	svc, err := GetServiceAs[services.EmailServiceInterface](sc, "email")
	return svc, err
}
// GetDatabase exposes the *sql.DB opened by Initialize. It is nil until
// Initialize has opened the database.
func (sc *ServiceContainer) GetDatabase() *sql.DB { return sc.db }

// GetConfig exposes the configuration passed to NewServiceContainer.
func (sc *ServiceContainer) GetConfig() *config.Config { return sc.cfg }

// GetLogger exposes the logger passed to NewServiceContainer.
func (sc *ServiceContainer) GetLogger() *observability.Logger { return sc.logger }
// Shutdown tears down all services and registered shutdown hooks while holding
// the container's write lock, and returns cleanup's combined error.
func (sc *ServiceContainer) Shutdown(ctx context.Context) error {
	sc.mu.Lock()
	defer sc.mu.Unlock()

	// All teardown logic lives in cleanup, which Initialize also uses for rollback.
	err := sc.cleanup(ctx)
	return err
}
// startupServices calls Startup(ctx) on every registered service that has such
// a method. It logs before and after each one and stops at the first failure,
// returning it wrapped with the service name. Services live in a map, so the
// order in which they are started is unspecified.
func (sc *ServiceContainer) startupServices(ctx context.Context) error {
	type starter interface {
		Startup(context.Context) error
	}

	for name, svc := range sc.services {
		s, ok := svc.(starter)
		if !ok {
			continue
		}
		sc.logger.Info(ctx, "Starting service", map[string]interface{}{"service": name})
		if err := s.Startup(ctx); err != nil {
			return contextutils.WrapErrorf(err, "failed to startup service %s", name)
		}
		sc.logger.Info(ctx, "Service started successfully", map[string]interface{}{"service": name})
	}
	return nil
}
// cleanup tears down everything the container owns. Every step is attempted
// even if earlier ones fail, and a combined error is returned if any step
// failed.
//
//  1. Each registered service with a Shutdown(context.Context) error method is
//     shut down. Services live in a map, so this order is unspecified.
//  2. The hooks in shutdownFuncs run in reverse registration order, last
//     registered first (e.g. the database close added by Initialize).
//
// shutdownFuncs is cleared afterwards. A later call, such as Shutdown after
// Initialize has already rolled back, therefore does not run the hooks again.
// Both callers in this file (Initialize, Shutdown) hold sc.mu while calling it.
func (sc *ServiceContainer) cleanup(ctx context.Context) error {
	var errs []error

	for name, svc := range sc.services {
		stopper, ok := svc.(interface{ Shutdown(context.Context) error })
		if !ok {
			continue
		}
		sc.logger.Info(ctx, "Shutting down service", map[string]interface{}{"service": name})
		if err := stopper.Shutdown(ctx); err != nil {
			sc.logger.Error(ctx, "Failed to shutdown service", err, map[string]interface{}{"service": name})
			errs = append(errs, contextutils.WrapErrorf(err, "service %s shutdown failed", name))
			continue
		}
		sc.logger.Info(ctx, "Service shutdown successfully", map[string]interface{}{"service": name})
	}

	for i := len(sc.shutdownFuncs) - 1; i >= 0; i-- {
		if err := sc.shutdownFuncs[i](ctx); err != nil {
			errs = append(errs, err)
		}
	}
	sc.shutdownFuncs = nil

	if len(errs) > 0 {
		return contextutils.ErrorWithContextf("shutdown errors: %v", errs)
	}
	return nil
}
// initializeServices constructs every application service and registers it in
// sc.services under the key used by the typed getters: "user", "learning",
// "question", "daily_question", "ai", "worker", "generation_hint", "oauth" and
// "email".
//
// Services that take another service as a constructor argument are built
// after it. The constructors read sc.db, so Initialize calls this only after
// opening the database. Services are only constructed here; those with
// Startup/Shutdown methods are driven by startupServices and cleanup. The
// context parameter is currently unused.
func (sc *ServiceContainer) initializeServices(_ context.Context) {
	// Services built from the database, config and/or logger only.
	sc.services["user"] = services.NewUserServiceWithLogger(sc.db, sc.cfg, sc.logger)
	learningService := services.NewLearningServiceWithLogger(sc.db, sc.cfg, sc.logger)
	sc.services["learning"] = learningService

	// The question service is given the learning service.
	questionService := services.NewQuestionServiceWithLogger(sc.db, learningService, sc.cfg, sc.logger)
	sc.services["question"] = questionService

	// The daily question service is given both the question and learning services.
	sc.services["daily_question"] = services.NewDailyQuestionService(sc.db, sc.logger, questionService, learningService)

	sc.services["ai"] = services.NewAIService(sc.cfg, sc.logger)
	sc.services["worker"] = services.NewWorkerServiceWithLogger(sc.db, sc.logger)
	sc.services["generation_hint"] = services.NewGenerationHintService(sc.db, sc.logger)
	sc.services["oauth"] = services.NewOAuthServiceWithLogger(sc.cfg, sc.logger)
	sc.services["email"] = services.CreateEmailService(sc.cfg, sc.logger)
}
// EnsureAdminUser delegates to the user service's EnsureAdminUserExists, using
// the admin username and password from the server configuration. It errors if
// the user service cannot be obtained from the container.
func (sc *ServiceContainer) EnsureAdminUser(ctx context.Context) error {
	userService, err := sc.GetUserService()
	if err != nil {
		return contextutils.WrapErrorf(err, "failed to get user service")
	}
	username, password := sc.cfg.Server.AdminUsername, sc.cfg.Server.AdminPassword
	return userService.EnsureAdminUserExists(ctx, username, password)
}
// Package handlers provides HTTP request handlers for the quiz application API.
package handlers
import (
"context"
"database/sql"
"encoding/json"
"errors"
"html/template"
"math"
"net/http"
"strconv"
"strings"
"time"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
)
// AdminHandler serves the backend administration endpoints: the dashboard page
// and its JSON data feed.
type AdminHandler struct {
	// Domain services queried for dashboard statistics.
	userService     services.UserServiceInterface
	questionService services.QuestionServiceInterface
	aiService       services.AIServiceInterface
	// Application configuration (worker port and base URL are exposed to the UI).
	config *config.Config
	// Parsed HTML templates; when nil the dashboard page is served as JSON.
	templates       *template.Template
	learningService services.LearningServiceInterface
	// Optional: worker health is only reported when this is non-nil.
	workerService services.WorkerServiceInterface
	logger        *observability.Logger
}
// NewAdminHandlerWithLogger wires an AdminHandler from its services, config
// and logger. Templates start out unset (nil), so the dashboard page falls back
// to JSON until templates are assigned.
func NewAdminHandlerWithLogger(userService services.UserServiceInterface, questionService services.QuestionServiceInterface, aiService services.AIServiceInterface, cfg *config.Config, learningService services.LearningServiceInterface, workerService services.WorkerServiceInterface, logger *observability.Logger) *AdminHandler {
	h := &AdminHandler{
		config: cfg,
		logger: logger,
	}
	h.userService = userService
	h.questionService = questionService
	h.aiService = aiService
	h.learningService = learningService
	h.workerService = workerService
	return h
}
// GetBackendAdminData responds with the dashboard data as JSON: aggregate user
// statistics, question statistics, worker health (only when a worker service
// is configured), AI concurrency statistics and the worker port/base URL.
//
// Only the user listing is fatal. Failures to fetch question stats or worker
// health are logged as warnings and replaced by an empty map / error marker.
func (h *AdminHandler) GetBackendAdminData(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_backend_admin_data")
	defer observability.FinishSpan(span, nil)

	users, err := h.userService.GetAllUsers(ctx)
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		HandleAppError(c, contextutils.WrapError(err, "failed to get users"))
		return
	}
	userStats := calculateUserAggregateStats(ctx, users, h.learningService, h.logger)

	questionStats, qErr := h.questionService.GetDetailedQuestionStats(ctx)
	if qErr != nil {
		h.logger.Warn(ctx, "Failed to get question stats", map[string]interface{}{"error": qErr.Error()})
		questionStats = make(map[string]interface{})
	}

	// Stays nil (JSON null) when no worker service is configured.
	var workerHealth map[string]interface{}
	if h.workerService != nil {
		if workerHealth, err = h.workerService.GetWorkerHealth(ctx); err != nil {
			h.logger.Warn(ctx, "Failed to get worker health", map[string]interface{}{"error": err.Error()})
			workerHealth = map[string]interface{}{"error": "Failed to get worker health"}
		}
	}

	ai := h.aiService.GetConcurrencyStats()
	c.JSON(http.StatusOK, gin.H{
		"user_stats":     userStats,
		"question_stats": questionStats,
		"worker_health":  workerHealth,
		"ai_concurrency_stats": map[string]interface{}{
			"active_requests":   ai.ActiveRequests,
			"max_concurrent":    ai.MaxConcurrent,
			"queued_requests":   ai.QueuedRequests,
			"total_requests":    ai.TotalRequests,
			"user_active_count": ai.UserActiveCount,
			"max_per_user":      ai.MaxPerUser,
		},
		"worker_port":     h.config.Server.WorkerPort,
		"worker_base_url": h.config.Server.WorkerBaseURL,
	})
}
// GetBackendAdminPage renders the backend administration dashboard.
//
// Data gathered:
//   - Per user: learning progress and question statistics. A failed lookup is
//     logged as a warning and replaced by zero-valued defaults.
//   - Global question statistics.
//   - Worker health, only when a worker service is configured.
//   - AI concurrency statistics.
//
// Only the initial user listing is fatal. When templates are loaded the data
// is rendered with "backend_admin.html"; otherwise it is returned as JSON.
func (h *AdminHandler) GetBackendAdminPage(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_backend_admin_page")
	defer observability.FinishSpan(span, nil)

	users, err := h.userService.GetAllUsers(ctx)
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		HandleAppError(c, contextutils.WrapError(err, "failed to get users"))
		return
	}

	type UserWithProgress struct {
		User               models.User
		Progress           *models.UserProgress
		QuestionStats      *services.UserQuestionStats
		UserQuestionCounts map[string]interface{}
	}
	// Pre-sized and non-nil, so the JSON fallback encodes "no users" as [] rather than null.
	usersWithProgress := make([]UserWithProgress, 0, len(users))
	for _, user := range users {
		progress, err := h.learningService.GetUserProgress(ctx, user.ID)
		if err != nil {
			h.logger.Warn(ctx, "Failed to get progress for user", map[string]interface{}{"user_id": user.ID, "error": err.Error()})
			progress = &models.UserProgress{
				CurrentLevel:   "A1",
				TotalQuestions: 0,
				CorrectAnswers: 0,
				AccuracyRate:   0,
			}
		}

		questionStats, err := h.learningService.GetUserQuestionStats(ctx, user.ID)
		if err != nil {
			h.logger.Warn(ctx, "Failed to get question stats for user", map[string]interface{}{"user_id": user.ID, "error": err.Error()})
			questionStats = &services.UserQuestionStats{
				UserID:        user.ID,
				TotalAnswered: 0,
			}
		}

		// Per-user question counts by type and level, taken from UserQuestionStats.
		userQuestionCounts := make(map[string]interface{})
		if questionStats != nil {
			userQuestionCounts["total_answered"] = questionStats.TotalAnswered
			userQuestionCounts["answered_by_type"] = questionStats.AnsweredByType
			userQuestionCounts["answered_by_level"] = questionStats.AnsweredByLevel
			userQuestionCounts["accuracy_by_type"] = questionStats.AccuracyByType
			userQuestionCounts["accuracy_by_level"] = questionStats.AccuracyByLevel
			userQuestionCounts["available_by_type"] = questionStats.AvailableByType
			userQuestionCounts["available_by_level"] = questionStats.AvailableByLevel
		}

		usersWithProgress = append(usersWithProgress, UserWithProgress{
			User:               user,
			Progress:           progress,
			QuestionStats:      questionStats,
			UserQuestionCounts: userQuestionCounts,
		})
	}

	questionStats, err := h.questionService.GetDetailedQuestionStats(ctx)
	if err != nil {
		h.logger.Warn(ctx, "Failed to get question stats", map[string]interface{}{"error": err.Error()})
		questionStats = make(map[string]interface{})
	}

	var workerHealth map[string]interface{}
	if h.workerService != nil {
		workerHealth, err = h.workerService.GetWorkerHealth(ctx)
		if err != nil {
			h.logger.Warn(ctx, "Failed to get worker health", map[string]interface{}{"error": err.Error()})
			workerHealth = map[string]interface{}{
				"error": "Failed to get worker health",
			}
		}
	}

	aiStatsStruct := h.aiService.GetConcurrencyStats()
	aiConcurrencyStats := map[string]interface{}{
		"active_requests":   aiStatsStruct.ActiveRequests,
		"max_concurrent":    aiStatsStruct.MaxConcurrent,
		"queued_requests":   aiStatsStruct.QueuedRequests,
		"total_requests":    aiStatsStruct.TotalRequests,
		"user_active_count": aiStatsStruct.UserActiveCount,
		"max_per_user":      aiStatsStruct.MaxPerUser,
	}

	data := gin.H{
		"Title":              "Backend Administration",
		"Users":              usersWithProgress,
		"QuestionStats":      questionStats,
		"WorkerHealth":       workerHealth,
		"AIConcurrencyStats": aiConcurrencyStats,
		"IsBackend":          true,
		"WorkerPort":         h.config.Server.WorkerPort,
		"CurrentPage":        "backend_admin",
		"WorkerBaseURL":      h.config.Server.WorkerBaseURL,
	}

	if h.templates == nil {
		c.JSON(http.StatusOK, data)
		return
	}

	// Render into memory first. Executing straight into c.Writer after setting
	// the HTML headers could leave partial output and a text/html Content-Type
	// in front of the error response written by HandleAppError.
	var page strings.Builder
	if err := h.templates.ExecuteTemplate(&page, "backend_admin.html", data); err != nil {
		h.logger.Error(ctx, "Template execution failed", err, map[string]interface{}{})
		HandleAppError(c, contextutils.WrapError(err, "failed to render template"))
		return
	}

	// The dashboard shows live state; prevent browsers and proxies from caching it.
	c.Header("Cache-Control", "no-cache, no-store, must-revalidate")
	c.Header("Pragma", "no-cache")
	c.Header("Expires", "0")
	c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(page.String()))
}
// UserData represents user information combined with their progress data.
// Progress is a pointer and may therefore be nil when no progress record
// has been attached.
type UserData struct {
	User     models.User
	Progress *models.UserProgress
}
// UserDataWithQuestions represents user information with questions and
// responses: the user, their progress and question statistics, aggregate
// counts, and the question objects themselves.
type UserDataWithQuestions struct {
	User          models.User
	Progress      *models.UserProgress
	QuestionStats *services.UserQuestionStats
	// Aggregate counts for this user — NOTE(review): how they are computed is
	// not visible here; confirm against the code that fills them.
	TotalQuestions int
	TotalResponses int
	// RecentQuestions holds recent questions as plain strings — NOTE(review):
	// exact string format (text vs. ID) not visible here; confirm with producer.
	RecentQuestions []string
	Questions       []*services.QuestionWithStats // Actual question objects with stats
}
// ReportedQuestionsData represents the structure for reported questions page
// data. It bundles per-user question data (Users) with the list of reported
// questions (ReportedQuestions).
type ReportedQuestionsData struct {
	Users             []UserDataWithQuestions
	ReportedQuestions []*services.ReportedQuestionWithUser
}
// ShowDatazPage was removed; use the frontend admin interface instead.
// MarkQuestionAsFixed marks the reported question identified by the ":id"
// path parameter as fixed, putting it back into rotation.
//
// It responds 200 on success and ErrInvalidFormat for a non-numeric ID. When
// the service reports contextutils.ErrRecordNotFound it responds
// ErrQuestionNotFound; any other failure is wrapped.
func (h *AdminHandler) MarkQuestionAsFixed(c *gin.Context) {
	ctx := c.Request.Context()

	id, parseErr := strconv.Atoi(c.Param("id"))
	if parseErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	markErr := h.questionService.MarkQuestionAsFixed(ctx, id)
	if markErr == nil {
		c.JSON(http.StatusOK, gin.H{"message": "Question marked as fixed successfully"})
		return
	}

	h.logger.Error(ctx, "Failed to mark question as fixed", markErr, map[string]interface{}{"question_id": id})
	// A missing row is surfaced as the question-specific not-found error.
	if contextutils.IsError(markErr, contextutils.ErrRecordNotFound) {
		HandleAppError(c, contextutils.ErrQuestionNotFound)
		return
	}
	HandleAppError(c, contextutils.WrapError(markErr, "failed to mark question as fixed"))
}
// UpdateQuestion replaces a question's content, correct answer and
// explanation.
//
// The question ID comes from the ":id" path parameter and the new values from
// the JSON body. Before saving, the content map is sanitized:
//   - nested `content.content` wrappers are unwrapped;
//   - duplicated top-level keys are dropped;
//   - a missing or null "options" becomes an empty list.
//
// When the query parameter mark_fixed=true (case-insensitive) is given, the
// question is also marked as fixed. Its reports are then deleted; a failed
// delete is only logged.
func (h *AdminHandler) UpdateQuestion(c *gin.Context) {
	id, convErr := strconv.Atoi(c.Param("id"))
	if convErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	var payload struct {
		Content       map[string]interface{} `json:"content" binding:"required"`
		CorrectAnswer int                    `json:"correct_answer" binding:"gte=0,lte=3"`
		Explanation   string                 `json:"explanation" binding:"required"`
	}
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request format",
			"",
			bindErr,
		))
		return
	}

	// Descend through any number of accidental `content.content` levels so
	// only the innermost object is stored.
	body := payload.Content
	for {
		nested, isMap := body["content"].(map[string]interface{})
		if !isMap {
			break
		}
		body = nested
	}

	// These keys belong at the top level of the request, not inside content.
	// Defensive cleanup while migrating to strict OpenAPI validation.
	for _, key := range []string{"correct_answer", "explanation", "change_reason"} {
		delete(body, key)
	}
	// Missing or null options are normalized to an empty list.
	if body["options"] == nil {
		body["options"] = []string{}
	}

	ctx := c.Request.Context()
	if updErr := h.questionService.UpdateQuestion(ctx, id, body, payload.CorrectAnswer, payload.Explanation); updErr != nil {
		h.logger.Error(ctx, "Failed to update question", updErr, map[string]interface{}{"question_id": id})
		if contextutils.IsError(updErr, contextutils.ErrRecordNotFound) {
			HandleAppError(c, contextutils.ErrQuestionNotFound)
		} else {
			HandleAppError(c, contextutils.WrapError(updErr, "failed to update question"))
		}
		return
	}

	if strings.EqualFold(c.Query("mark_fixed"), "true") {
		// Marking as fixed sets the question status back to active.
		if fixErr := h.questionService.MarkQuestionAsFixed(ctx, id); fixErr != nil {
			h.logger.Error(ctx, "Failed to mark question as fixed after update", fixErr, map[string]interface{}{"question_id": id})
			HandleAppError(c, contextutils.WrapError(fixErr, "failed to mark question as fixed"))
			return
		}
		// Clearing reports is best-effort: failures are logged, not returned.
		const clearReportsSQL = `DELETE FROM question_reports WHERE question_id = $1`
		if _, delErr := h.questionService.DB().ExecContext(ctx, clearReportsSQL, id); delErr != nil {
			h.logger.Warn(ctx, "Failed to clear question reports", map[string]interface{}{"question_id": id, "error": delErr.Error()})
		}
	}

	c.JSON(http.StatusOK, gin.H{"success": true, "message": "Question updated successfully"})
}
// FixQuestionWithAI asks an AI provider to suggest a fix for a reported
// question. It either returns the suggestion for review or applies it.
//
// The question ID comes from the ":id" path parameter. The provider/model is
// chosen in this order:
//  1. the earliest reporter (by report creation time) who has both an AI
//     provider and model configured;
//  2. the current admin user's configured provider and model;
//  3. the first model of the first provider in the global config.
//
// If none is available, the handler responds with ErrAIConfigInvalid.
//
// The prompt is rendered from services.AIFixPromptTemplate using:
//   - the question content and the fix schema for its type;
//   - the selected report's reason;
//   - an example for the type, when one exists;
//   - an optional "additional_context" string from the JSON request body,
//     which may be empty or absent.
//
// Without query parameters the response is {"original": ..., "suggestion": ...}.
// With apply=true (case-insensitive), the merged suggestion is written back,
// the question is marked as fixed and its reports are deleted. The last two
// steps are best-effort and only logged on failure.
func (h *AdminHandler) FixQuestionWithAI(c *gin.Context) {
	questionIDStr := c.Param("id")
	questionID, err := strconv.Atoi(questionIDStr)
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}
	ctx := c.Request.Context()

	// Get the original question.
	question, err := h.questionService.GetQuestionByID(ctx, questionID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get question", err, map[string]interface{}{"question_id": questionID})
		if errors.Is(err, sql.ErrNoRows) {
			HandleAppError(c, contextutils.ErrQuestionNotFound)
			return
		}
		HandleAppError(c, contextutils.WrapError(err, "failed to get question"))
		return
	}

	// Load the question's reporters, oldest report first, to pick an AI
	// provider/model from the reporting user(s).
	db := h.questionService.DB()
	rows, err := db.QueryContext(ctx, `SELECT u.id, u.username, u.ai_provider, u.ai_model, qr.report_reason FROM question_reports qr JOIN users u ON qr.reported_by_user_id = u.id WHERE qr.question_id = $1 ORDER BY qr.created_at ASC`, questionID)
	if err != nil {
		h.logger.Error(ctx, "Failed to query question reports", err, map[string]interface{}{"question_id": questionID})
		HandleAppError(c, contextutils.WrapError(err, "failed to get report details"))
		return
	}
	defer func() {
		if err := rows.Close(); err != nil {
			h.logger.Warn(ctx, "Failed to close report rows", map[string]interface{}{"error": err.Error(), "question_id": questionID})
		}
	}()

	var reporterID int
	var reporterUsername string
	var reporterProvider sql.NullString
	var reporterModel sql.NullString
	var singleReason sql.NullString
	foundProvider := false
	for rows.Next() {
		var uid int
		var uname string
		var prov sql.NullString
		var mod sql.NullString
		var reason sql.NullString
		if err := rows.Scan(&uid, &uname, &prov, &mod, &reason); err != nil {
			h.logger.Warn(ctx, "Failed to scan report row", map[string]interface{}{"error": err.Error(), "question_id": questionID})
			continue
		}
		// Prefer the first reporter that has an AI provider+model configured.
		if prov.Valid && prov.String != "" && mod.Valid && mod.String != "" {
			reporterID = uid
			reporterUsername = uname
			reporterProvider = prov
			reporterModel = mod
			singleReason = reason
			foundProvider = true
			break
		}
		// Keep the first reporter as fallback (no provider).
		if reporterID == 0 {
			reporterID = uid
			reporterUsername = uname
			reporterProvider = prov
			reporterModel = mod
			singleReason = reason
		}
	}
	// rows.Err reports errors that ended iteration early, so it is only
	// meaningful after the loop. It is treated as best-effort: whatever rows
	// were read are still used.
	if err := rows.Err(); err != nil {
		h.logger.Warn(ctx, "Error iterating report rows", map[string]interface{}{"error": err.Error(), "question_id": questionID})
	}

	if !foundProvider {
		// No reporting user has AI configured: fall back to the admin user's
		// AI settings, then to the global default provider.
		h.logger.Info(ctx, "No reporting user has AI configured; attempting fallback to admin or global provider", map[string]interface{}{"question_id": questionID})
		// Try to get the current admin user from context/session.
		var adminUserID int
		if uid, err := GetCurrentUserID(c); err == nil {
			adminUserID = uid
		}
		// Try the admin user's configured provider/model.
		if adminUserID != 0 {
			adminUser, err := h.userService.GetUserByID(ctx, adminUserID)
			if err == nil && adminUser != nil && adminUser.AIProvider.Valid && adminUser.AIProvider.String != "" && adminUser.AIModel.Valid && adminUser.AIModel.String != "" {
				reporterID = adminUser.ID
				reporterUsername = adminUser.Username
				reporterProvider = adminUser.AIProvider
				reporterModel = adminUser.AIModel
				foundProvider = true
				h.logger.Info(ctx, "Falling back to admin user's AI provider", map[string]interface{}{"admin_id": adminUserID, "provider": adminUser.AIProvider.String, "model": adminUser.AIModel.String})
			}
		}
		// If still not found, use the first provider and model from the global config.
		if !foundProvider && h.config != nil && len(h.config.Providers) > 0 {
			p := h.config.Providers[0]
			if len(p.Models) > 0 {
				reporterProvider = sql.NullString{String: p.Code, Valid: true}
				reporterModel = sql.NullString{String: p.Models[0].Code, Valid: true}
				reporterUsername = "system"
				foundProvider = true
				h.logger.Info(ctx, "Falling back to global configured AI provider", map[string]interface{}{"provider": p.Code, "model": p.Models[0].Code})
			}
		}
		if !foundProvider {
			h.logger.Warn(ctx, "No AI provider configured for reporting users and no fallback available", map[string]interface{}{"question_id": questionID})
			HandleAppError(c, contextutils.ErrAIConfigInvalid)
			return
		}
	}

	// Look up the saved API key for the chosen user/provider. A failed lookup
	// is tolerated and leaves the key empty.
	// NOTE(review): in the global-config fallback, reporterID is still the
	// first reporter's ID (or 0), so that user's saved key is used; confirm.
	savedKey, _ := h.userService.GetUserAPIKey(ctx, reporterID, reporterProvider.String)
	userCfg := &services.UserAIConfig{
		Provider: reporterProvider.String,
		Model:    reporterModel.String,
		APIKey:   savedKey,
		Username: reporterUsername,
	}

	// Prepare template data for the structured fix prompt.
	questionContentJSON, err := question.MarshalContentToJSON()
	if err != nil {
		// Best-effort, as before: continue with whatever was produced, but
		// make the failure visible.
		h.logger.Warn(ctx, "Failed to marshal question content for AI prompt", map[string]interface{}{"question_id": questionID, "error": err.Error()})
	}
	// Resolve the schema for the prompt; fail if there is none.
	schema, err := services.GetFixSchema(question.Type)
	if err != nil {
		h.logger.Error(ctx, "No schema available for question type", err, map[string]interface{}{"question_id": questionID, "type": question.Type})
		HandleAppError(c, contextutils.ErrAIConfigInvalid)
		return
	}

	// Read optional additional_context from the POST body. ShouldBindJSON is
	// used instead of BindJSON because BindJSON aborts the request with a 400
	// (and records a bind error) when the body is empty, which is allowed here.
	var body struct {
		AdditionalContext string `json:"additional_context"`
	}
	_ = c.ShouldBindJSON(&body) // body is optional; a decode error just means no extra context

	tmplData := services.AITemplateData{
		CurrentQuestionJSON: questionContentJSON,
		ExampleContent:      "", // filled below if an example is available
		SchemaForPrompt:     schema,
		ReportReasons:       []string{},
		AdditionalContext:   body.AdditionalContext,
	}
	if singleReason.Valid {
		tmplData.ReportReasons = []string{singleReason.String}
	}
	// Load an example for this question type if available.
	if ex, err := h.aiService.TemplateManager().LoadExample(string(question.Type)); err == nil {
		tmplData.ExampleContent = ex
	}
	prompt, err := h.aiService.TemplateManager().RenderTemplate(services.AIFixPromptTemplate, tmplData)
	if err != nil {
		h.logger.Error(ctx, "Failed to render AI fix prompt", err, map[string]interface{}{"question_id": questionID})
		HandleAppError(c, contextutils.WrapError(err, "failed to build AI prompt"))
		return
	}

	// The same fix schema doubles as the grammar for providers that support
	// one; it was already resolved above, so it is not fetched again.
	grammar := ""
	if h.aiService.SupportsGrammarField(userCfg.Provider) {
		grammar = schema
	}

	// Call the AI service with the constructed prompt and grammar.
	respStr, err := h.aiService.CallWithPrompt(ctx, userCfg, prompt, grammar)
	if err != nil {
		h.logger.Error(ctx, "AI service call failed", err, map[string]interface{}{"question_id": questionID, "provider": userCfg.Provider})
		HandleAppError(c, contextutils.WrapError(err, "AI service error"))
		return
	}

	// Parse the AI response as JSON. If that fails, retry on the outermost
	// {...} substring, since models sometimes wrap JSON in prose.
	var aiResp map[string]interface{}
	if err := json.Unmarshal([]byte(respStr), &aiResp); err != nil {
		start := strings.Index(respStr, "{")
		end := strings.LastIndex(respStr, "}")
		if start >= 0 && end > start {
			candidate := respStr[start : end+1]
			if err2 := json.Unmarshal([]byte(candidate), &aiResp); err2 != nil {
				h.logger.Error(ctx, "Failed to parse AI response as JSON", err2, map[string]interface{}{"question_id": questionID})
				HandleAppError(c, contextutils.ErrAIResponseInvalid)
				return
			}
		} else {
			h.logger.Error(ctx, "AI did not return JSON", nil, map[string]interface{}{"question_id": questionID})
			HandleAppError(c, contextutils.ErrAIResponseInvalid)
			return
		}
	}

	// Start from the original question map so required top-level fields are preserved.
	originalMap := map[string]interface{}{}
	if b, err := json.Marshal(question); err == nil {
		_ = json.Unmarshal(b, &originalMap)
	}
	// Merge and normalize the AI suggestion into the original map.
	suggestion := MergeAISuggestion(originalMap, aiResp)
	// Attach admin-provided additional context so the frontend can display it.
	if body.AdditionalContext != "" {
		suggestion["additional_context"] = body.AdditionalContext
	}

	// With apply=true, apply the suggestion directly and mark the question fixed.
	if strings.ToLower(c.Query("apply")) == "true" {
		// Guard the assertion: a suggestion without a "content" object must
		// not panic the handler.
		updateContent, ok := suggestion["content"].(map[string]interface{})
		if !ok {
			h.logger.Error(ctx, "AI suggestion has no content object", nil, map[string]interface{}{"question_id": questionID})
			HandleAppError(c, contextutils.ErrAIResponseInvalid)
			return
		}
		// Extract correct_answer as int (JSON numbers decode as float64).
		correctAnswer := 0
		if ca, ok := updateContent["correct_answer"]; ok {
			switch v := ca.(type) {
			case float64:
				correctAnswer = int(v)
			case int:
				correctAnswer = v
			}
		}
		explanation := ""
		if ex, ok := updateContent["explanation"].(string); ok {
			explanation = ex
		}
		if err := h.questionService.UpdateQuestion(ctx, questionID, updateContent, correctAnswer, explanation); err != nil {
			h.logger.Error(ctx, "Failed to update question with AI suggestion", err, map[string]interface{}{"question_id": questionID})
			HandleAppError(c, contextutils.WrapError(err, "failed to apply suggestion"))
			return
		}
		if err := h.questionService.MarkQuestionAsFixed(ctx, questionID); err != nil {
			h.logger.Warn(ctx, "Failed to mark question as fixed after applying suggestion", map[string]interface{}{"question_id": questionID, "error": err.Error()})
		}
		if _, err := db.ExecContext(ctx, `DELETE FROM question_reports WHERE question_id = $1`, questionID); err != nil {
			h.logger.Warn(ctx, "Failed to clear question reports after applying suggestion", map[string]interface{}{"question_id": questionID, "error": err.Error()})
		}
		c.JSON(http.StatusOK, gin.H{"success": true, "message": "Suggestion applied"})
		return
	}

	// Return the original question and the merged AI suggestion for frontend review.
	c.JSON(http.StatusOK, gin.H{
		"original":   question,
		"suggestion": suggestion,
	})
}
// ServeDatazJS was removed; use the frontend admin interface instead.
// GetAIConcurrencyStats responds with the local AI service instance's
// concurrency metrics, nested under the "ai_concurrency" key.
func (h *AdminHandler) GetAIConcurrencyStats(c *gin.Context) {
	payload := gin.H{"ai_concurrency": h.aiService.GetConcurrencyStats()}
	c.JSON(http.StatusOK, payload)
}
// ClearUserData deletes all user activity data via
// userService.ClearUserData while keeping the user accounts themselves.
// It responds with a success message or a wrapped error.
func (h *AdminHandler) ClearUserData(c *gin.Context) {
	ctx := c.Request.Context()
	clearErr := h.userService.ClearUserData(ctx)
	if clearErr == nil {
		c.JSON(http.StatusOK, gin.H{"success": true, "message": "User data cleared successfully (users preserved)"})
		return
	}
	h.logger.Error(ctx, "Failed to clear user data", clearErr, map[string]interface{}{})
	HandleAppError(c, contextutils.WrapError(clearErr, "failed to clear user data"))
}
// ClearDatabase resets the database through userService.ResetDatabase and
// responds with a success message or a wrapped error.
func (h *AdminHandler) ClearDatabase(c *gin.Context) {
	ctx := c.Request.Context()
	resetErr := h.userService.ResetDatabase(ctx)
	if resetErr == nil {
		c.JSON(http.StatusOK, gin.H{"success": true, "message": "Database cleared successfully"})
		return
	}
	h.logger.Error(ctx, "Failed to clear database", resetErr, map[string]interface{}{})
	HandleAppError(c, contextutils.WrapError(resetErr, "failed to clear database"))
}
// GetQuestion returns a single question, identified by the ":id" path
// parameter, for editing.
//
// A missing question (sql.ErrNoRows or contextutils.ErrRecordNotFound from the
// lookup) is reported as ErrQuestionNotFound. Any other lookup failure, such
// as a database outage, is wrapped and reported as such rather than being
// masked as "not found".
func (h *AdminHandler) GetQuestion(c *gin.Context) {
	questionID, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}
	ctx := c.Request.Context()
	question, err := h.questionService.GetQuestionByID(ctx, questionID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get question", err, map[string]interface{}{"question_id": questionID})
		if errors.Is(err, sql.ErrNoRows) || contextutils.IsError(err, contextutils.ErrRecordNotFound) {
			HandleAppError(c, contextutils.ErrQuestionNotFound)
			return
		}
		HandleAppError(c, contextutils.WrapError(err, "failed to get question"))
		return
	}
	c.JSON(http.StatusOK, question)
}
// GetUsersForQuestion lists the users assigned to the question identified by
// the ":id" path parameter. The response contains "users" and "total_count",
// both as returned by the question service.
func (h *AdminHandler) GetUsersForQuestion(c *gin.Context) {
	id, convErr := strconv.Atoi(c.Param("id"))
	if convErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	ctx := c.Request.Context()
	assigned, count, lookupErr := h.questionService.GetUsersForQuestion(ctx, id)
	if lookupErr != nil {
		h.logger.Error(ctx, "Failed to get users for question", lookupErr, map[string]interface{}{"question_id": id})
		HandleAppError(c, contextutils.WrapError(lookupErr, "failed to get users for question"))
		return
	}

	resp := gin.H{}
	resp["users"] = assigned
	resp["total_count"] = count
	c.JSON(http.StatusOK, resp)
}
// AssignUsersToQuestion assigns the users listed in the JSON body
// ({"user_ids": [...]}) to the question identified by the ":id" path
// parameter.
//
// It responds with:
//   - an invalid-input error when the body cannot be bound or the list is
//     empty;
//   - ErrQuestionNotFound when the question lookup returns sql.ErrNoRows;
//   - a wrapped error for any other failure.
func (h *AdminHandler) AssignUsersToQuestion(c *gin.Context) {
	questionIDStr := c.Param("id")
	questionID, err := strconv.Atoi(questionIDStr)
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}
	var request struct {
		UserIDs []int `json:"user_ids" binding:"required"`
	}
	if err := c.ShouldBindJSON(&request); err != nil {
		// Keep the binding error as the cause instead of discarding it, the
		// same way UnassignUsersFromQuestion and AssignRole report bad bodies.
		HandleAppError(c, contextutils.NewAppErrorWithCause(contextutils.ErrorCodeInvalidInput, contextutils.SeverityWarn, "Invalid request body", "", err))
		return
	}
	// binding:"required" rejects a missing/null list but not an empty one.
	if len(request.UserIDs) == 0 {
		HandleAppError(c, contextutils.ErrInvalidInput)
		return
	}
	ctx := c.Request.Context()
	// Check that the question exists before changing assignments.
	if _, err := h.questionService.GetQuestionByID(ctx, questionID); err != nil {
		h.logger.Error(ctx, "Failed to get question", err, map[string]interface{}{"question_id": questionID})
		if errors.Is(err, sql.ErrNoRows) {
			HandleAppError(c, contextutils.ErrQuestionNotFound)
			return
		}
		HandleAppError(c, contextutils.WrapError(err, "failed to get question"))
		return
	}
	if err := h.questionService.AssignUsersToQuestion(ctx, questionID, request.UserIDs); err != nil {
		h.logger.Error(ctx, "Failed to assign users to question", err, map[string]interface{}{
			"question_id": questionID,
			"user_ids":    request.UserIDs,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to assign users to question"))
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": "Users assigned to question successfully"})
}
// UnassignUsersFromQuestion removes the users listed in the JSON body
// ({"user_ids": [...]}) from the question identified by the ":id" path
// parameter. The list must be non-empty and the question must exist;
// sql.ErrNoRows from the lookup is reported as ErrQuestionNotFound.
func (h *AdminHandler) UnassignUsersFromQuestion(c *gin.Context) {
	id, convErr := strconv.Atoi(c.Param("id"))
	if convErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	var payload struct {
		UserIDs []int `json:"user_ids" binding:"required"`
	}
	if bindErr := c.ShouldBindJSON(&payload); bindErr != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(contextutils.ErrorCodeInvalidInput, contextutils.SeverityWarn, "Invalid request body", "", bindErr))
		return
	}
	if len(payload.UserIDs) == 0 {
		HandleAppError(c, contextutils.ErrInvalidInput)
		return
	}

	ctx := c.Request.Context()
	// Check that the question exists before changing assignments.
	if _, lookupErr := h.questionService.GetQuestionByID(ctx, id); lookupErr != nil {
		h.logger.Error(ctx, "Failed to get question", lookupErr, map[string]interface{}{"question_id": id})
		if errors.Is(lookupErr, sql.ErrNoRows) {
			HandleAppError(c, contextutils.ErrQuestionNotFound)
		} else {
			HandleAppError(c, contextutils.WrapError(lookupErr, "failed to get question"))
		}
		return
	}

	if unassignErr := h.questionService.UnassignUsersFromQuestion(ctx, id, payload.UserIDs); unassignErr != nil {
		h.logger.Error(ctx, "Failed to unassign users from question", unassignErr, map[string]interface{}{
			"question_id": id,
			"user_ids":    payload.UserIDs,
		})
		HandleAppError(c, contextutils.WrapError(unassignErr, "failed to unassign users from question"))
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": "Users unassigned from question successfully"})
}
// DeleteQuestion deletes the question identified by the ":id" path
// parameter. A service-reported contextutils.ErrRecordNotFound becomes
// ErrQuestionNotFound; any other failure is wrapped.
func (h *AdminHandler) DeleteQuestion(c *gin.Context) {
	id, convErr := strconv.Atoi(c.Param("id"))
	if convErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	ctx := c.Request.Context()
	delErr := h.questionService.DeleteQuestion(ctx, id)
	if delErr == nil {
		c.JSON(http.StatusOK, gin.H{"message": "Question deleted successfully"})
		return
	}

	h.logger.Error(ctx, "Failed to delete question", delErr, map[string]interface{}{"question_id": id})
	if contextutils.IsError(delErr, contextutils.ErrRecordNotFound) {
		HandleAppError(c, contextutils.ErrQuestionNotFound)
		return
	}
	HandleAppError(c, contextutils.WrapError(delErr, "failed to delete question"))
}
// GetQuestionsPaginated returns one page of questions, with response
// statistics, for the user given by the required "user_id" query parameter.
//
// Optional "search", "type" and "status" filters are read via ParseFilters.
// Paging uses ParsePagination(c, 1, 10, 100) — presumably default page 1,
// default size 10, max size 100; confirm against ParsePagination.
func (h *AdminHandler) GetQuestionsPaginated(c *gin.Context) {
	rawUserID := c.Query("user_id")
	if rawUserID == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}
	userID, convErr := strconv.Atoi(rawUserID)
	if convErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	page, pageSize := ParsePagination(c, 1, 10, 100)
	filters := ParseFilters(c, "search", "type", "status")

	ctx := c.Request.Context()
	questions, total, listErr := h.questionService.GetQuestionsPaginated(
		ctx, userID, page, pageSize,
		filters["search"], filters["type"], filters["status"],
	)
	if listErr != nil {
		h.logger.Error(ctx, "Failed to get paginated questions", listErr, map[string]interface{}{
			"user_id": userID,
			"page":    page,
			"size":    pageSize,
		})
		HandleAppError(c, contextutils.WrapError(listErr, "failed to get questions"))
		return
	}

	apiQuestions := make([]map[string]interface{}, 0, len(questions))
	for _, q := range questions {
		apiQuestions = append(apiQuestions, convertQuestionWithStatsToAPIMap(q))
	}
	totalPages := int(math.Ceil(float64(total) / float64(pageSize)))

	c.JSON(http.StatusOK, gin.H{
		"questions": apiQuestions,
		"pagination": gin.H{
			"page":        page,
			"page_size":   pageSize,
			"total":       total,
			"total_pages": totalPages,
		},
	})
}
// GetAllQuestions returns one page of all questions together with overall
// question statistics.
//
// Optional filters are "search", "type", "status", "language" and "level",
// plus "user_id", which must be an integer when present. Paging uses
// ParsePagination(c, 1, 20, 100) — presumably default page 1, default size 20,
// max size 100; confirm against ParsePagination. A failure to load the stats
// is logged and an empty stats object is returned instead.
func (h *AdminHandler) GetAllQuestions(c *gin.Context) {
	page, pageSize := ParsePagination(c, 1, 20, 100)
	f := ParseFilters(c, "search", "type", "status", "language", "level")

	// user_id is optional; nil means "no user filter".
	var userID *int
	if raw := c.Query("user_id"); raw != "" {
		parsed, convErr := strconv.Atoi(raw)
		if convErr != nil {
			HandleAppError(c, contextutils.ErrInvalidFormat)
			return
		}
		userID = &parsed
	}

	ctx := c.Request.Context()
	search := f["search"]
	questions, total, listErr := h.questionService.GetAllQuestionsPaginated(
		ctx, page, pageSize,
		search, f["type"], f["status"], f["language"], f["level"],
		userID,
	)
	if listErr != nil {
		h.logger.Error(ctx, "Failed to get all questions", listErr, map[string]interface{}{
			"page":   page,
			"size":   pageSize,
			"search": search,
		})
		HandleAppError(c, contextutils.WrapError(listErr, "failed to get questions"))
		return
	}

	stats, statsErr := h.questionService.GetQuestionStats(ctx)
	if statsErr != nil {
		h.logger.Warn(ctx, "Failed to get question stats", map[string]interface{}{"error": statsErr.Error()})
		stats = map[string]interface{}{}
	}

	apiQuestions := make([]map[string]interface{}, 0, len(questions))
	for _, q := range questions {
		apiQuestions = append(apiQuestions, convertQuestionWithStatsToAPIMap(q))
	}
	totalPages := int(math.Ceil(float64(total) / float64(pageSize)))

	c.JSON(http.StatusOK, gin.H{
		"questions": apiQuestions,
		"pagination": gin.H{
			"page":        page,
			"page_size":   pageSize,
			"total":       total,
			"total_pages": totalPages,
		},
		"stats": stats,
	})
}
// GetReportedQuestionsPaginated returns one page of reported questions
// together with reported-question statistics.
//
// Optional filters are "search", "type", "language" and "level". Paging uses
// ParsePagination(c, 1, 20, 100) — presumably default page 1, default size 20,
// max size 100; confirm against ParsePagination. A failure to load the stats
// is logged and an empty stats object is returned instead.
func (h *AdminHandler) GetReportedQuestionsPaginated(c *gin.Context) {
	page, pageSize := ParsePagination(c, 1, 20, 100)
	f := ParseFilters(c, "search", "type", "language", "level")
	search := f["search"]

	ctx := c.Request.Context()
	questions, total, listErr := h.questionService.GetReportedQuestionsPaginated(
		ctx, page, pageSize,
		search, f["type"], f["language"], f["level"],
	)
	if listErr != nil {
		h.logger.Error(ctx, "Failed to get reported questions", listErr, map[string]interface{}{
			"page":   page,
			"size":   pageSize,
			"search": search,
		})
		HandleAppError(c, contextutils.WrapError(listErr, "failed to get reported questions"))
		return
	}

	stats, statsErr := h.questionService.GetReportedQuestionsStats(ctx)
	if statsErr != nil {
		h.logger.Warn(ctx, "Failed to get reported questions stats", map[string]interface{}{"error": statsErr.Error()})
		stats = map[string]interface{}{}
	}

	apiQuestions := make([]map[string]interface{}, 0, len(questions))
	for _, q := range questions {
		apiQuestions = append(apiQuestions, convertQuestionWithStatsToAPIMap(q))
	}
	totalPages := int(math.Ceil(float64(total) / float64(pageSize)))

	c.JSON(http.StatusOK, gin.H{
		"questions": apiQuestions,
		"pagination": gin.H{
			"page":        page,
			"page_size":   pageSize,
			"total":       total,
			"total_pages": totalPages,
		},
		"stats": stats,
	})
}
// ClearUserDataForUser deletes all activity data for the user identified by
// the ":id" path parameter while keeping the user record itself.
//
// The user must exist: a lookup error is wrapped, and a nil user is reported
// as contextutils.ErrRecordNotFound.
func (h *AdminHandler) ClearUserDataForUser(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "clear_user_data_for_user")
	defer observability.FinishSpan(span, nil)

	targetID, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	target, err := h.userService.GetUserByID(ctx, targetID)
	switch {
	case err != nil:
		h.logger.Error(ctx, "Failed to get user for clear data operation", err, map[string]interface{}{"user_id": targetID})
		HandleAppError(c, contextutils.WrapError(err, "failed to get user"))
		return
	case target == nil:
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	if clearErr := h.userService.ClearUserDataForUser(ctx, targetID); clearErr != nil {
		h.logger.Error(ctx, "Failed to clear user data for user", clearErr, map[string]interface{}{"user_id": targetID})
		HandleAppError(c, contextutils.WrapError(clearErr, "failed to clear user data for user"))
		return
	}
	c.JSON(http.StatusOK, gin.H{"success": true, "message": "User data cleared successfully (user preserved)"})
}
// GetConfigz writes the handler's merged configuration (h.config) as
// indented JSON.
// NOTE(review): the whole config struct is serialized; confirm it carries no
// secrets, or that this route is restricted appropriately.
func (h *AdminHandler) GetConfigz(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_configz")
	defer observability.FinishSpan(span, nil)

	cfg := h.config
	c.IndentedJSON(http.StatusOK, cfg)
}
// GetRoles returns the roles available in the system as {"roles": [...]}.
//
// The list is currently hard-coded because there is no role service yet. All
// timestamps come from a single time.Now() reading, so CreatedAt and
// UpdatedAt are identical within and across roles in one response.
func (h *AdminHandler) GetRoles(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_roles")
	defer observability.FinishSpan(span, nil)

	// TODO: query the database once a role service exists.
	now := time.Now()
	roles := []models.Role{
		{ID: 1, Name: "user", Description: "Normal site access", CreatedAt: now, UpdatedAt: now},
		{ID: 2, Name: "admin", Description: "Administrative access to all features", CreatedAt: now, UpdatedAt: now},
	}
	c.JSON(http.StatusOK, gin.H{"roles": roles})
}
// GetUserRoles returns {"roles": [...]} for the user identified by the ":id"
// path parameter.
//
// The user must exist: a lookup error is wrapped, and a nil user is reported
// as contextutils.ErrRecordNotFound.
func (h *AdminHandler) GetUserRoles(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_user_roles")
	defer observability.FinishSpan(span, nil)

	targetID, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	target, err := h.userService.GetUserByID(ctx, targetID)
	switch {
	case err != nil:
		h.logger.Error(ctx, "Failed to get user for roles operation", err, map[string]interface{}{"user_id": targetID})
		HandleAppError(c, contextutils.WrapError(err, "failed to get user"))
		return
	case target == nil:
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	userRoles, rolesErr := h.userService.GetUserRoles(ctx, targetID)
	if rolesErr != nil {
		h.logger.Error(ctx, "Failed to get user roles", rolesErr, map[string]interface{}{"user_id": targetID})
		HandleAppError(c, contextutils.WrapError(rolesErr, "failed to get user roles"))
		return
	}
	c.JSON(http.StatusOK, gin.H{"roles": userRoles})
}
// AssignRole grants the role given by "role_id" in the JSON body to the user
// identified by the ":id" path parameter.
//
// Checks run in this order:
//  1. the path ID format;
//  2. that the user exists;
//  3. the request body;
//  4. an explicit self-or-admin authorization check.
//
// NOTE(review): step 4 is skipped entirely when GetCurrentUserID fails,
// relying on the route being admin-only; confirm this fail-open is intended.
func (h *AdminHandler) AssignRole(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "assign_role")
	defer observability.FinishSpan(span, nil)

	targetID, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	target, err := h.userService.GetUserByID(ctx, targetID)
	switch {
	case err != nil:
		h.logger.Error(ctx, "Failed to get user for role assignment", err, map[string]interface{}{"user_id": targetID})
		HandleAppError(c, contextutils.WrapError(err, "failed to get user"))
		return
	case target == nil:
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	var req struct {
		RoleID int `json:"role_id" binding:"required"`
	}
	if bindErr := c.ShouldBindJSON(&req); bindErr != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(contextutils.ErrorCodeInvalidInput, contextutils.SeverityWarn, "Invalid request body", "", bindErr))
		return
	}

	if requesterID, idErr := GetCurrentUserID(c); idErr == nil {
		if authErr := RequireSelfOrAdmin(ctx, h.userService, requesterID, targetID); authErr != nil {
			if errors.Is(authErr, ErrForbidden) {
				HandleAppError(c, contextutils.ErrForbidden)
				return
			}
			h.logger.Error(ctx, "Failed to check authorization", authErr, map[string]interface{}{"user_id": requesterID})
			HandleAppError(c, contextutils.WrapError(authErr, "failed to check authorization"))
			return
		}
	}

	if assignErr := h.userService.AssignRole(ctx, targetID, req.RoleID); assignErr != nil {
		h.logger.Error(ctx, "Failed to assign role to user", assignErr, map[string]interface{}{"user_id": targetID, "role_id": req.RoleID})
		HandleAppError(c, contextutils.WrapError(assignErr, "failed to assign role"))
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": "Role assigned successfully"})
}
// RemoveRole removes a role from a user.
//
// The target user ID comes from the "id" path parameter and the role ID from
// the "roleId" path parameter. The handler verifies that the target user
// exists, checks that the requester may act on that user, and then calls
// userService.RemoveRole. "Does not have role" and record-not-found failures
// are reported as ErrRecordNotFound.
//
// The route is expected to be admin-only. The requester check here is a second
// line of defense and therefore fails closed: if the requester cannot be
// identified, the request is rejected with ErrForbidden rather than skipping
// the check.
func (h *AdminHandler) RemoveRole(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "remove_role")
	defer observability.FinishSpan(span, nil)
	userIDStr := c.Param("id")
	userID, err := strconv.Atoi(userIDStr)
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}
	// Check if user exists before removing role
	user, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get user for role removal", err, map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.WrapError(err, "failed to get user"))
		return
	}
	if user == nil {
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}
	roleIDStr := c.Param("roleId")
	roleID, err := strconv.Atoi(roleIDStr)
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}
	// Defense in depth: fail closed when the requester cannot be identified.
	// The previous code silently skipped authorization in that case.
	currentUserID, err := GetCurrentUserID(c)
	if err != nil {
		h.logger.Warn(ctx, "Unable to identify requester for role removal", map[string]interface{}{"user_id": userID, "error": err.Error()})
		HandleAppError(c, contextutils.ErrForbidden)
		return
	}
	if err := RequireSelfOrAdmin(ctx, h.userService, currentUserID, userID); err != nil {
		// An unauthenticated requester (ID 0) is a permission failure, not a
		// server error.
		if errors.Is(err, ErrForbidden) || errors.Is(err, ErrUnauthenticated) {
			HandleAppError(c, contextutils.ErrForbidden)
			return
		}
		h.logger.Error(ctx, "Failed to check authorization", err, map[string]interface{}{"user_id": currentUserID})
		HandleAppError(c, contextutils.WrapError(err, "failed to check authorization"))
		return
	}
	if err := h.userService.RemoveRole(ctx, userID, roleID); err != nil {
		h.logger.Error(ctx, "Failed to remove role", err, map[string]interface{}{"user_id": userID, "role_id": roleID})
		// NOTE(review): matching on the error text is brittle; a sentinel
		// error from the service would be safer if one can be introduced.
		if strings.Contains(err.Error(), "does not have role") {
			HandleAppError(c, contextutils.ErrRecordNotFound)
			return
		}
		// User or role not found.
		if contextutils.IsError(err, contextutils.ErrRecordNotFound) {
			HandleAppError(c, contextutils.ErrRecordNotFound)
			return
		}
		HandleAppError(c, contextutils.WrapError(err, "failed to remove role"))
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": "Role removed successfully"})
}
// calculateUserAggregateStats calculates aggregate statistics for all users
func calculateUserAggregateStats(ctx context.Context, users []models.User, learningService services.LearningServiceInterface, logger *observability.Logger) map[string]interface{} {
stats := map[string]interface{}{
"total_users": len(users),
"by_language": make(map[string]int),
"by_level": make(map[string]int),
"by_ai_provider": make(map[string]int),
"by_ai_model": make(map[string]int),
"ai_enabled": 0,
"ai_disabled": 0,
"active_users": 0,
"inactive_users": 0,
"total_questions_answered": 0,
"total_correct_answers": 0,
"average_accuracy": 0.0,
}
activeThreshold := time.Now().AddDate(0, 0, -7)
for _, user := range users {
lang := "unknown"
if user.PreferredLanguage.Valid {
lang = user.PreferredLanguage.String
}
stats["by_language"].(map[string]int)[lang]++
level := "unknown"
if user.CurrentLevel.Valid {
level = user.CurrentLevel.String
}
stats["by_level"].(map[string]int)[level]++
provider := "none"
if user.AIProvider.Valid {
provider = user.AIProvider.String
}
stats["by_ai_provider"].(map[string]int)[provider]++
model := "none"
if user.AIModel.Valid {
model = user.AIModel.String
}
stats["by_ai_model"].(map[string]int)[model]++
if user.AIEnabled.Valid && user.AIEnabled.Bool {
aiEnabled := stats["ai_enabled"].(int)
stats["ai_enabled"] = aiEnabled + 1
} else {
aiDisabled := stats["ai_disabled"].(int)
stats["ai_disabled"] = aiDisabled + 1
}
if user.LastActive.Valid {
lastActive := user.LastActive.Time
if lastActive.After(activeThreshold) {
activeUsers := stats["active_users"].(int)
stats["active_users"] = activeUsers + 1
} else {
inactiveUsers := stats["inactive_users"].(int)
stats["inactive_users"] = inactiveUsers + 1
}
} else {
inactiveUsers := stats["inactive_users"].(int)
stats["inactive_users"] = inactiveUsers + 1
}
progress, err := learningService.GetUserProgress(ctx, user.ID)
if err != nil {
logger.Warn(ctx, "Failed to get progress for user", map[string]interface{}{"user_id": user.ID, "error": err.Error()})
continue
}
if progress != nil {
totalAnswered := stats["total_questions_answered"].(int)
stats["total_questions_answered"] = totalAnswered + progress.TotalQuestions
totalCorrect := stats["total_correct_answers"].(int)
stats["total_correct_answers"] = totalCorrect + progress.CorrectAnswers
}
}
totalAnswered := stats["total_questions_answered"].(int)
if totalAnswered > 0 {
stats["average_accuracy"] = float64(stats["total_correct_answers"].(int)) / float64(totalAnswered) * 100.0
}
return stats
}
package handlers
import (
"encoding/json"
"fmt"
"strconv"
"strings"
)
// MergeAISuggestion merges an AI response into a copy of the original question
// map and returns that copy.
//
// Merge rules:
//   - Top-level fields of original are preserved.
//   - When aiResp["content"] is a map, its entries are merged into the copy's
//     "content" map, overriding existing keys.
//   - Top-level "correct_answer" and "explanation" in aiResp are written into
//     "content", taking precedence over the same keys inside aiResp["content"].
//     They are never placed at the top level, so they are not duplicated.
//   - A top-level "change_reason" is copied to the top level of the result.
//
// The merged content is then passed through NormalizeContent.
//
// Neither argument is modified, and a nil original is treated as an empty map.
func MergeAISuggestion(original, aiResp map[string]interface{}) map[string]interface{} {
	// Deep-copy original through a JSON round trip so nested maps in the result
	// are not shared with the caller.
	out := map[string]interface{}{}
	if b, err := json.Marshal(original); err == nil {
		// b is json.Marshal output for a map (a JSON object or null), which
		// always decodes back into a map[string]interface{}.
		_ = json.Unmarshal(b, &out)
	} else {
		// original holds a value JSON cannot encode (e.g. a func, channel or
		// NaN). Fall back to a shallow copy instead of silently dropping every
		// field. "content" is copied separately below.
		for k, v := range original {
			out[k] = v
		}
	}
	// A nil original marshals to "null", and unmarshalling null sets out to
	// nil; writing into it below would panic.
	if out == nil {
		out = map[string]interface{}{}
	}

	// Merge into a fresh content map so nothing the caller can still reach is
	// written to (this matters for the shallow-copy fallback).
	contentMap := map[string]interface{}{}
	if existing, ok := out["content"].(map[string]interface{}); ok {
		for k, v := range existing {
			contentMap[k] = v
		}
	}
	out["content"] = contentMap

	// Merge AI content.
	if aiContent, ok := aiResp["content"].(map[string]interface{}); ok {
		for k, v := range aiContent {
			contentMap[k] = v
		}
	}
	// Top-level AI fields belong inside content. aiResp itself is left intact.
	if ca, ok := aiResp["correct_answer"]; ok {
		contentMap["correct_answer"] = ca
	}
	if ex, ok := aiResp["explanation"]; ok {
		contentMap["explanation"] = ex
	}
	if cr, ok := aiResp["change_reason"]; ok {
		out["change_reason"] = cr
	}

	NormalizeContent(contentMap)
	return out
}
// NormalizeContent sanitizes a question content map in place.
//
//   - "options" becomes a trimmed, de-duplicated []string with empty entries
//     removed. Accepted inputs are []interface{} (non-string items are
//     dropped), []string, a JSON array string, or a comma/newline separated
//     string. Any other type removes the key.
//   - "correct_answer" becomes an int. Accepted inputs are float64 (truncated),
//     int, or a numeric string. Anything else removes the key.
//   - When options are present, correct_answer is re-pointed so it still refers
//     to the same option text after options were de-duplicated or had entries
//     dropped. An index that is out of range, or that pointed at a dropped
//     entry, is reset to 0.
//   - "explanation", "question", "passage" and "sentence" are forced to
//     strings. nil becomes "".
func NormalizeContent(contentMap map[string]interface{}) {
	// cleanOptions trims, drops empties, and de-duplicates raw. It returns the
	// cleaned list plus pos, where pos[i] is the new index of raw[i] (the first
	// occurrence for duplicates) or -1 when raw[i] was dropped.
	cleanOptions := func(raw []string) ([]string, []int) {
		cleaned := make([]string, 0, len(raw))
		pos := make([]int, len(raw))
		firstIndex := make(map[string]int, len(raw))
		for i, s := range raw {
			s = strings.TrimSpace(s)
			if s == "" {
				pos[i] = -1
				continue
			}
			if j, dup := firstIndex[s]; dup {
				pos[i] = j
				continue
			}
			firstIndex[s] = len(cleaned)
			pos[i] = len(cleaned)
			cleaned = append(cleaned, s)
		}
		return cleaned, pos
	}

	// remap translates pre-normalization option indices into post-normalization
	// ones. It stays nil when options were absent.
	var remap []int
	if optsRaw, ok := contentMap["options"]; ok {
		var raw []string
		valid := true
		switch opts := optsRaw.(type) {
		case []interface{}:
			raw = make([]string, len(opts))
			for i, it := range opts {
				raw[i], _ = it.(string) // non-strings become "" and are dropped
			}
		case []string:
			raw = opts
		case string:
			if err := json.Unmarshal([]byte(opts), &raw); err != nil {
				raw = strings.FieldsFunc(opts, func(r rune) bool { return r == '\n' || r == ',' })
			}
		default:
			valid = false
		}
		if valid {
			var cleaned []string
			cleaned, remap = cleanOptions(raw)
			contentMap["options"] = cleaned
		} else {
			delete(contentMap, "options")
		}
	}

	// Normalize correct_answer to an int.
	if ca, ok := contentMap["correct_answer"]; ok {
		switch v := ca.(type) {
		case float64:
			contentMap["correct_answer"] = int(v)
		case int:
			// already an int
		case string:
			if n, err := strconv.Atoi(strings.TrimSpace(v)); err == nil {
				contentMap["correct_answer"] = n
			} else {
				delete(contentMap, "correct_answer")
			}
		default:
			delete(contentMap, "correct_answer")
		}
	}

	// Keep correct_answer pointing at the same option, then clamp it.
	if ca, ok := contentMap["correct_answer"].(int); ok {
		if opts, ok := contentMap["options"].([]string); ok {
			if ca >= 0 && ca < len(remap) {
				ca = remap[ca]
			}
			if ca < 0 || ca >= len(opts) {
				ca = 0
			}
			contentMap["correct_answer"] = ca
		}
	}

	// Ensure simple string fields are strings.
	for _, k := range []string{"explanation", "question", "passage", "sentence"} {
		v, ok := contentMap[k]
		if !ok {
			continue
		}
		switch t := v.(type) {
		case string:
			// already a string
		case nil:
			// fmt.Sprint(nil) would produce the literal text "<nil>".
			contentMap[k] = ""
		default:
			contentMap[k] = fmt.Sprint(t)
		}
	}
}
package handlers
import (
"crypto/rand"
"errors"
"net/http"
"regexp"
"strings"
"time"
"quizapp/internal/api"
"quizapp/internal/config"
"quizapp/internal/middleware"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
openapi_types "github.com/oapi-codegen/runtime/types"
"go.opentelemetry.io/otel/attribute"
)
// AuthHandler serves the authentication endpoints: password login and logout,
// session status, signup, and the Google OAuth flow.
type AuthHandler struct {
	userService  services.UserServiceInterface // user lookup, creation and authentication
	oauthService *services.OAuthService        // Google OAuth URL generation and code exchange
	config       *config.Config                // may be nil; handlers check before use
	logger       *observability.Logger
}

// NewAuthHandler builds an AuthHandler from its dependencies. cfg may be nil,
// in which case handlers fall back to their built-in defaults.
func NewAuthHandler(userService services.UserServiceInterface, oauthService *services.OAuthService, cfg *config.Config, logger *observability.Logger) *AuthHandler {
	h := new(AuthHandler)
	h.userService = userService
	h.oauthService = oauthService
	h.config = cfg
	h.logger = logger
	return h
}
// Login authenticates a username/password pair.
//
// On success it updates the user's last-active time (a failure is only
// logged), stores the user ID and username in the session, and responds with
// api.LoginResponse containing the user profile. An authentication error and
// a nil user are both reported to the client as ErrInvalidCredentials.
//
// The context returned by TraceHandlerFunction is passed to every downstream
// call, so their spans are parented to this handler's span.
func (h *AuthHandler) Login(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "login")
	defer observability.FinishSpan(span, nil)
	var req api.LoginRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request body",
			"",
			err,
		))
		return
	}
	// Record the attempt; only whether a password was provided, never its value.
	span.SetAttributes(
		attribute.String("auth.username", req.Username),
		attribute.Bool("auth.password_provided", req.Password != ""),
	)
	// Authenticate user against database
	user, err := h.userService.AuthenticateUser(ctx, req.Username, req.Password)
	if err != nil {
		h.logger.Error(ctx, "Authentication failed for user", err, map[string]interface{}{"username": req.Username})
		HandleAppError(c, contextutils.ErrInvalidCredentials)
		return
	}
	if user == nil {
		HandleAppError(c, contextutils.ErrInvalidCredentials)
		return
	}
	span.SetAttributes(
		attribute.Int("user.id", user.ID),
		attribute.String("user.username", user.Username),
		attribute.Bool("user.email_provided", user.Email.Valid),
		attribute.String("user.language", user.PreferredLanguage.String),
		attribute.String("user.level", user.CurrentLevel.String),
	)
	// Best effort: a failed last-active update must not fail the login.
	if err := h.userService.UpdateLastActive(ctx, user.ID); err != nil {
		h.logger.Warn(ctx, "Failed to update last active for user", map[string]interface{}{"user_id": user.ID, "error": err.Error()})
	}
	// Create session
	session := sessions.Default(c)
	session.Set(middleware.UserIDKey, user.ID)
	session.Set(middleware.UsernameKey, user.Username)
	if err := session.Save(); err != nil {
		h.logger.Error(ctx, "Failed to save session", err, map[string]interface{}{"error": err.Error()})
		HandleAppError(c, contextutils.WrapError(err, "failed to create session"))
		return
	}
	// Convert models.User to api.User with proper API key checking
	apiUser := convertUserToAPIWithService(ctx, user, h.userService)
	// Return user info (without API key)
	c.JSON(http.StatusOK, api.LoginResponse{
		Success: boolPtr(true),
		Message: stringPtr("Login successful"),
		User:    &apiUser,
	})
}
// Logout clears the current session and responds with a success message.
//
// Before clearing, the session's user ID and username (when present with the
// expected int/string types) are recorded on the trace span. Checked type
// assertions are used so that a malformed session value cannot panic the
// handler and stop the user from logging out.
func (h *AuthHandler) Logout(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "logout")
	defer observability.FinishSpan(span, nil)
	session := sessions.Default(c)
	if userID, ok := session.Get(middleware.UserIDKey).(int); ok {
		span.SetAttributes(attribute.Int("user.id", userID))
	}
	if username, ok := session.Get(middleware.UsernameKey).(string); ok {
		span.SetAttributes(attribute.String("user.username", username))
	}
	session.Clear()
	if err := session.Save(); err != nil {
		HandleAppError(c, contextutils.WrapError(err, "failed to clear session"))
		return
	}
	c.JSON(http.StatusOK, api.SuccessResponse{
		Success: true,
		Message: stringPtr("Logout successful"),
	})
}
// Status reports whether the current session is authenticated.
//
// It responds 200 with {"authenticated": bool, "user": ...}.
//   - No user ID in the session: authenticated=false.
//   - A session user ID that is not an int: authenticated=false, and the
//     session is cleared (previously this panicked).
//   - A user ID with no matching user: authenticated=false, and the session
//     is cleared.
//   - Otherwise: the user's last-active time is updated (a failure is only
//     logged) and the user profile is returned.
//
// A lookup error responds with ErrInternalError.
func (h *AuthHandler) Status(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "status")
	defer observability.FinishSpan(span, nil)
	respondSignedOut := func() {
		c.JSON(http.StatusOK, gin.H{
			"authenticated": false,
			"user":          nil,
		})
	}
	session := sessions.Default(c)
	rawID := session.Get(middleware.UserIDKey)
	if rawID == nil {
		span.SetAttributes(attribute.Bool("auth.authenticated", false))
		respondSignedOut()
		return
	}
	userID, ok := rawID.(int)
	if !ok {
		h.logger.Warn(ctx, "Invalid user ID in session", map[string]interface{}{"user_id": rawID})
		session.Clear()
		if err := session.Save(); err != nil {
			h.logger.Error(ctx, "Error saving session", err, map[string]interface{}{"error": err.Error()})
		}
		span.SetAttributes(attribute.Bool("auth.authenticated", false))
		respondSignedOut()
		return
	}
	span.SetAttributes(
		attribute.Bool("auth.authenticated", true),
		attribute.Int("user.id", userID),
	)
	user, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Error getting user by ID", err, map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.ErrInternalError)
		return
	}
	if user == nil {
		// User not found, clear session
		session.Clear()
		if err := session.Save(); err != nil {
			h.logger.Error(ctx, "Error saving session", err, map[string]interface{}{"error": err.Error()})
		}
		span.SetAttributes(attribute.Bool("auth.user_found", false))
		respondSignedOut()
		return
	}
	span.SetAttributes(
		attribute.Bool("auth.user_found", true),
		attribute.String("user.username", user.Username),
		attribute.Bool("user.email_provided", user.Email.Valid),
		attribute.String("user.language", user.PreferredLanguage.String),
		attribute.String("user.level", user.CurrentLevel.String),
		attribute.Bool("user.ai_enabled", user.AIEnabled.Bool),
		attribute.String("user.ai_provider", user.AIProvider.String),
		attribute.String("user.ai_model", user.AIModel.String),
	)
	// Best effort: do not fail the status request for this.
	if err := h.userService.UpdateLastActive(ctx, user.ID); err != nil {
		h.logger.Error(ctx, "Error updating last active", err, map[string]interface{}{"user_id": user.ID})
	}
	// Convert models.User to api.User with proper API key checking
	apiUser := convertUserToAPIWithService(ctx, user, h.userService)
	c.JSON(http.StatusOK, gin.H{
		"authenticated": true,
		"user":          &apiUser,
	})
}
// Check is a minimal endpoint for reverse-proxy auth_request subrequests.
//
// Authentication is enforced entirely by the RequireAuth middleware placed in
// front of it, which answers unauthenticated requests with 401. Reaching this
// handler therefore means the request is authenticated, and it simply replies
// 204 No Content.
func (*AuthHandler) Check(c *gin.Context) {
	c.Status(http.StatusNoContent)
}
// Signup handles user registration requests
func (h *AuthHandler) Signup(c *gin.Context) {
_, span := observability.TraceHandlerFunction(c.Request.Context(), "signup")
defer observability.FinishSpan(span, nil)
// Check if signups are disabled
if h.config != nil && h.config.IsSignupDisabled() {
span.SetAttributes(attribute.Bool("auth.signups_disabled", true))
HandleAppError(c, contextutils.ErrForbidden)
return
}
span.SetAttributes(attribute.Bool("auth.signups_disabled", false))
var req api.UserCreateRequest
if err := c.ShouldBindJSON(&req); err != nil {
if errors.Is(err, openapi_types.ErrValidationEmail) {
HandleAppError(c, contextutils.ErrInvalidInput)
return
}
HandleAppError(c, contextutils.NewAppErrorWithCause(
contextutils.ErrorCodeInvalidInput,
contextutils.SeverityWarn,
"Invalid request body",
"",
err,
))
return
}
// Set span attributes for request data
span.SetAttributes(
attribute.String("signup.username", req.Username),
attribute.Bool("signup.password_provided", req.Password != ""),
attribute.Bool("signup.email_provided", req.Email != nil && *req.Email != ""),
attribute.Bool("signup.language_provided", req.PreferredLanguage != nil && *req.PreferredLanguage != ""),
attribute.Bool("signup.level_provided", req.CurrentLevel != nil && *req.CurrentLevel != ""),
attribute.Bool("signup.timezone_provided", req.Timezone != nil && *req.Timezone != ""),
)
// Validate required fields
if req.Username == "" {
HandleAppError(c, contextutils.ErrMissingRequired)
return
}
if req.Password == "" {
HandleAppError(c, contextutils.ErrMissingRequired)
return
}
if req.Email == nil || *req.Email == "" {
HandleAppError(c, contextutils.ErrMissingRequired)
return
}
// Validate username format (3-50 characters, alphanumeric + underscore)
if len(req.Username) < 3 || len(req.Username) > 50 {
HandleAppError(c, contextutils.ErrInvalidFormat)
return
}
usernameRegex := regexp.MustCompile(`^[a-zA-Z0-9_]+$`)
if !usernameRegex.MatchString(req.Username) {
HandleAppError(c, contextutils.ErrInvalidFormat)
return
}
// Validate password (minimum 8 characters)
if len(req.Password) < 8 {
HandleAppError(c, contextutils.ErrInvalidFormat)
return
}
// Validate email format (convert to string)
if !contextutils.IsValidEmail(string(*req.Email)) {
HandleAppError(c, contextutils.ErrInvalidFormat)
return
}
// Normalize email to lowercase
email := strings.ToLower(string(*req.Email))
h.logger.Info(c.Request.Context(), "Attempting signup for user", map[string]interface{}{"username": req.Username, "email": email})
// Check if username already exists
existingUser, err := h.userService.GetUserByUsername(c.Request.Context(), req.Username)
if err != nil {
h.logger.Error(c.Request.Context(), "Error checking username uniqueness", err, map[string]interface{}{"username": req.Username})
HandleAppError(c, contextutils.ErrInternalError)
return
}
if existingUser != nil {
span.SetAttributes(attribute.Bool("signup.username_exists", true))
HandleAppError(c, contextutils.ErrRecordExists)
return
}
// Check if email already exists
existingUserByEmail, err := h.userService.GetUserByEmail(c.Request.Context(), email)
if err != nil {
h.logger.Error(c.Request.Context(), "Error checking email uniqueness", err, map[string]interface{}{"email": email})
HandleAppError(c, contextutils.ErrInternalError)
return
}
if existingUserByEmail != nil {
span.SetAttributes(attribute.Bool("signup.email_exists", true))
HandleAppError(c, contextutils.ErrRecordExists)
return
}
// Set default values for optional fields
language := "italian" // Default to first language in the list
if h.config != nil {
// Get available languages from config
languages := h.config.GetLanguages()
if len(languages) > 0 {
language = languages[0]
}
}
if req.PreferredLanguage != nil && *req.PreferredLanguage != "" {
language = *req.PreferredLanguage
}
// Choose canonical default level for the selected language (first level in config)
level := ""
levels := []string{}
if h.config != nil {
levels = h.config.GetLevelsForLanguage(language)
if len(levels) > 0 {
level = levels[0]
}
}
// If client provided a level, require it to be a canonical code for the language.
if req.CurrentLevel != nil && *req.CurrentLevel != "" {
provided := *req.CurrentLevel
matched := false
for _, l := range levels {
if strings.EqualFold(l, provided) {
level = l
matched = true
break
}
}
if !matched {
HandleAppError(c, contextutils.ErrInvalidFormat)
return
}
}
timezone := "UTC" // Default timezone
if req.Timezone != nil && *req.Timezone != "" {
timezone = *req.Timezone
}
// Update span attributes with final values
span.SetAttributes(
attribute.String("signup.language", language),
attribute.String("signup.level", level),
attribute.String("signup.timezone", timezone),
)
// Create user with email and timezone (no AI settings)
user, err := h.userService.CreateUserWithEmailAndTimezone(c.Request.Context(), req.Username, email, timezone, language, level)
if err != nil {
h.logger.Error(c.Request.Context(), "Error creating user", err, map[string]interface{}{"username": req.Username, "email": email})
HandleAppError(c, contextutils.WrapError(err, "failed to create user account"))
return
}
// Now set the password hash
if err := h.userService.UpdateUserPassword(c.Request.Context(), user.ID, req.Password); err != nil {
h.logger.Error(c.Request.Context(), "Error setting user password", err, map[string]interface{}{"user_id": user.ID})
// Try to clean up the user we just created
if deleteErr := h.userService.DeleteUser(c.Request.Context(), user.ID); deleteErr != nil {
h.logger.Error(c.Request.Context(), "Error cleaning up user after password set failure", err, map[string]interface{}{"user_id": user.ID, "error": deleteErr.Error()})
}
HandleAppError(c, contextutils.WrapError(err, "failed to create user account"))
return
}
// Update span attributes with created user info
span.SetAttributes(
attribute.Int("user.id", user.ID),
attribute.String("user.username", user.Username),
attribute.String("user.email", email),
)
h.logger.Info(c.Request.Context(), "Successfully created user", map[string]interface{}{"username": req.Username, "user_id": user.ID})
// Return success response (no session created, no auto-login)
c.JSON(http.StatusCreated, api.SuccessResponse{
Success: true,
Message: stringPtr("Account created successfully. Please log in."),
})
}
// GoogleLogin starts the Google OAuth flow.
//
// It generates a random state value and stores it in the session as
// "oauth_state" for GoogleCallback to verify. It stores the optional
// "redirect_uri" query parameter as "oauth_redirect_uri", or removes any
// stored value when the parameter is absent. It responds with the Google
// authorization URL under "auth_url".
func (h *AuthHandler) GoogleLogin(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "google_login")
	defer observability.FinishSpan(span, nil)
	// Generate a state parameter for security
	state := generateRandomState()
	redirectURI := c.Query("redirect_uri")
	span.SetAttributes(
		attribute.String("oauth.provider", "google"),
		attribute.String("oauth.state", state),
		attribute.String("oauth.redirect_uri", redirectURI),
	)
	// Store state and redirect URI in session for verification
	session := sessions.Default(c)
	session.Set("oauth_state", state)
	if redirectURI != "" {
		// NOTE(review): redirect_uri is stored as supplied and returned by
		// GoogleCallback. If the frontend navigates to it, consider restricting
		// it to relative or same-origin URLs to avoid an open redirect.
		session.Set("oauth_redirect_uri", redirectURI)
	} else {
		// Drop a redirect left over from an earlier, abandoned login attempt so
		// the callback does not return a destination this request never asked
		// for.
		session.Delete("oauth_redirect_uri")
	}
	if err := session.Save(); err != nil {
		HandleAppError(c, contextutils.WrapError(err, "failed to save session"))
		return
	}
	// Generate Google OAuth URL
	authURL := h.oauthService.GetGoogleAuthURL(ctx, state)
	c.JSON(http.StatusOK, gin.H{
		"auth_url": authURL,
	})
}
// GoogleCallback handles the OAuth callback from Google
func (h *AuthHandler) GoogleCallback(c *gin.Context) {
_, span := observability.TraceHandlerFunction(c.Request.Context(), "google_callback")
defer observability.FinishSpan(span, nil)
// Get the authorization code and state from query parameters
code := c.Query("code")
state := c.Query("state")
// Set span attributes
span.SetAttributes(
attribute.String("oauth.provider", "google"),
attribute.Bool("oauth.code_provided", code != ""),
attribute.String("oauth.state", state),
)
h.logger.Info(c.Request.Context(), "Google OAuth callback received", map[string]interface{}{"code": code, "state": state})
if code == "" {
HandleAppError(c, contextutils.ErrMissingRequired)
return
}
// Verify state parameter for OAuth security (CSRF protection)
session := sessions.Default(c)
storedState := session.Get("oauth_state")
h.logger.Info(c.Request.Context(), "OAuth state verification", map[string]interface{}{"stored_state": storedState, "received_state": state})
// Enforce strict state verification for security
if storedState == nil {
h.logger.Error(c.Request.Context(), "No OAuth state found in session - possible CSRF attack or session issue", nil, map[string]interface{}{"state": state})
span.SetAttributes(attribute.Bool("oauth.state_valid", false))
HandleAppError(c, contextutils.ErrOAuthStateMismatch)
return
}
if storedState.(string) != state {
h.logger.Error(c.Request.Context(), "OAuth state mismatch - possible CSRF attack", nil, map[string]interface{}{"stored_state": storedState.(string), "received_state": state})
span.SetAttributes(attribute.Bool("oauth.state_valid", false))
HandleAppError(c, contextutils.ErrOAuthStateMismatch)
return
}
span.SetAttributes(attribute.Bool("oauth.state_valid", true))
h.logger.Info(c.Request.Context(), "OAuth state verification successful")
// Check if user is already authenticated (prevent duplicate callbacks)
existingUserID := session.Get(middleware.UserIDKey)
if existingUserID != nil {
h.logger.Info(c.Request.Context(), "User already authenticated during OAuth callback", map[string]interface{}{
"user_id": existingUserID.(int),
})
span.SetAttributes(attribute.Bool("oauth.duplicate_callback", true))
// Get user information for the response
user, err := h.userService.GetUserByID(c.Request.Context(), existingUserID.(int))
if err != nil {
h.logger.Error(c.Request.Context(), "Error getting user by ID", err, map[string]interface{}{"user_id": existingUserID.(int)})
HandleAppError(c, contextutils.ErrInternalError)
return
}
if user == nil {
h.logger.Error(c.Request.Context(), "User not found", nil, map[string]interface{}{"user_id": existingUserID.(int)})
HandleAppError(c, contextutils.ErrInternalError)
return
}
// Convert models.User to api.User with proper API key checking
apiUser := convertUserToAPIWithService(c.Request.Context(), user, h.userService)
// Return success response for already authenticated user
response := api.LoginResponse{
Success: boolPtr(true),
Message: stringPtr("Already authenticated"),
User: &apiUser,
}
c.JSON(http.StatusOK, response)
return
}
// Get the stored redirect URI from session
storedRedirectURI := session.Get("oauth_redirect_uri")
var redirectURI string
if storedRedirectURI != nil {
redirectURI = storedRedirectURI.(string)
}
// Clear the state and redirect URI from session
session.Delete("oauth_state")
session.Delete("oauth_redirect_uri")
if err := session.Save(); err != nil {
h.logger.Error(c.Request.Context(), "Failed to save session", err, map[string]interface{}{"error": err.Error()})
HandleAppError(c, contextutils.WrapError(err, "failed to save session"))
return
}
// Authenticate user with Google OAuth
user, err := h.oauthService.AuthenticateGoogleUser(c.Request.Context(), code, h.userService)
if err != nil {
h.logger.Error(c.Request.Context(), "Google OAuth authentication failed", err, map[string]interface{}{"error": err.Error()})
// Check if this is a signup disabled error (structured)
if errors.Is(err, services.ErrSignupsDisabled) {
span.SetAttributes(attribute.Bool("oauth.signups_disabled", true))
HandleAppError(c, contextutils.ErrForbidden)
return
}
// Provide better error messages to the frontend using structured error checking
errorMessage := "Authentication failed"
if errors.Is(err, services.ErrOAuthCodeAlreadyUsed) {
errorMessage = "This authentication link has already been used. Please try signing in again."
} else if errors.Is(err, services.ErrOAuthClientConfig) {
errorMessage = "OAuth configuration error. Please contact support."
} else if errors.Is(err, services.ErrOAuthInvalidRequest) {
errorMessage = "Invalid authentication request. Please try again."
} else if errors.Is(err, services.ErrOAuthUnauthorized) {
errorMessage = "OAuth client is not authorized. Please contact support."
} else if errors.Is(err, services.ErrOAuthUnsupportedGrant) {
errorMessage = "Unsupported OAuth grant type. Please contact support."
}
HandleAppError(c, contextutils.WrapError(err, errorMessage))
return
}
// Update span attributes with user info
span.SetAttributes(
attribute.Int("user.id", user.ID),
attribute.String("user.username", user.Username),
attribute.Bool("user.email_provided", user.Email.Valid),
attribute.String("user.language", user.PreferredLanguage.String),
attribute.String("user.level", user.CurrentLevel.String),
attribute.Bool("user.is_new", user.CreatedAt.After(time.Now().Add(-5*time.Minute))), // Rough check if user was just created
)
// Update last active
if err := h.userService.UpdateLastActive(c.Request.Context(), user.ID); err != nil {
h.logger.Warn(c.Request.Context(), "Failed to update last active for user", map[string]interface{}{"user_id": user.ID, "error": err.Error()})
}
// Create session
session.Set(middleware.UserIDKey, user.ID)
session.Set(middleware.UsernameKey, user.Username)
h.logger.Info(c.Request.Context(), "Setting session for user", map[string]interface{}{"user_id": user.ID, "username": user.Username})
if err := session.Save(); err != nil {
h.logger.Error(c.Request.Context(), "Failed to save session", err, map[string]interface{}{"error": err.Error()})
HandleAppError(c, contextutils.WrapError(err, "failed to create session"))
return
}
// Convert models.User to api.User with proper API key checking
apiUser := convertUserToAPIWithService(c.Request.Context(), user, h.userService)
h.logger.Info(c.Request.Context(), "Google OAuth successful for user", map[string]interface{}{"username": user.Username, "user_id": user.ID})
// Return user info with redirect URI if available
response := api.LoginResponse{
Success: boolPtr(true),
Message: stringPtr("Google authentication successful"),
User: &apiUser,
}
// Add redirect URI to response if it was stored
if redirectURI != "" {
response.RedirectUri = &redirectURI
}
c.JSON(http.StatusOK, response)
}
// generateRandomState returns a 32-character OAuth state value drawn uniformly
// from [a-zA-Z0-9] using crypto/rand.
//
// Random bytes at or above the largest multiple of len(charset) that fits in
// a byte are rejected (rejection sampling). A plain byte%62 would make the
// first 256%62 characters slightly more likely. Bytes are read in batches
// instead of one rand.Read call per character.
//
// It panics if the system's cryptographic random source fails; there is no
// fallback to weaker randomness.
func generateRandomState() string {
	const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
	const length = 32
	// 248 for a 62-character charset: bytes in [0, limit) map evenly.
	const limit = 256 - 256%len(charset)

	out := make([]byte, 0, length)
	buf := make([]byte, length*2)
	for len(out) < length {
		if _, err := rand.Read(buf); err != nil {
			panic("Cryptographic random number generation failed: " + err.Error())
		}
		for _, rb := range buf {
			if int(rb) >= limit {
				continue
			}
			out = append(out, charset[int(rb)%len(charset)])
			if len(out) == length {
				break
			}
		}
	}
	return string(out)
}
// SignupStatus reports the signup configuration to the client.
//
// It responds with:
//   - signups_disabled: config.IsSignupDisabled(), or false without config.
//   - oauth_whitelist_enabled: true when any allowed domains or allowed emails
//     are configured.
//   - allowed_domains, allowed_emails: the configured lists (null when unset).
//
// NOTE(review): this response lists the configured allowed emails to whoever
// can call this endpoint — confirm that exposure is intended.
func (h *AuthHandler) SignupStatus(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "signup_status")
	defer observability.FinishSpan(span, nil)

	var (
		signupsDisabled       bool
		oauthWhitelistEnabled bool
		allowedDomains        []string
		allowedEmails         []string
	)
	if cfg := h.config; cfg != nil {
		signupsDisabled = cfg.IsSignupDisabled()
		if sys := cfg.System; sys != nil {
			allowedDomains = sys.Auth.AllowedDomains
			allowedEmails = sys.Auth.AllowedEmails
			oauthWhitelistEnabled = len(allowedDomains) > 0 || len(allowedEmails) > 0
		}
	}

	span.SetAttributes(
		attribute.Bool("auth.signups_disabled", signupsDisabled),
		attribute.Bool("auth.config_available", h.config != nil),
		attribute.Bool("auth.oauth_whitelist_enabled", oauthWhitelistEnabled),
	)
	c.JSON(http.StatusOK, gin.H{
		"signups_disabled":        signupsDisabled,
		"oauth_whitelist_enabled": oauthWhitelistEnabled,
		"allowed_domains":         allowedDomains,
		"allowed_emails":          allowedEmails,
	})
}
package handlers
import (
"context"
"errors"
"quizapp/internal/middleware"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
// ErrUnauthenticated is returned when no current user can be determined.
var ErrUnauthenticated = errors.New("user not authenticated")

// ErrInvalidUserID is returned when the stored user identifier is not an int.
var ErrInvalidUserID = errors.New("invalid user id")

// ErrForbidden is returned when the user lacks permission for the operation.
var ErrForbidden = errors.New("forbidden")
// GetCurrentUserID returns the authenticated user's ID.
//
// The value placed on the Gin context by the auth middleware (under
// middleware.UserIDKey) takes precedence. Only when that key is absent is the
// session store consulted.
//
// Errors:
//   - ErrUnauthenticated: neither the context nor the session holds an ID.
//   - ErrInvalidUserID: an ID is present but is not an int.
func GetCurrentUserID(c *gin.Context) (int, error) {
	raw, found := c.Get(middleware.UserIDKey)
	if !found {
		// Context not populated by middleware: fall back to the session.
		raw = sessions.Default(c).Get(middleware.UserIDKey)
		if raw == nil {
			return 0, ErrUnauthenticated
		}
	}

	id, ok := raw.(int)
	if !ok {
		return 0, ErrInvalidUserID
	}
	return id, nil
}
// authzAdminChecker is the single capability the authorization helpers need
// from a user service: deciding whether a given user is an administrator.
// Any type with a matching IsAdmin method satisfies it implicitly, which keeps
// the helpers decoupled from the concrete user service.
type authzAdminChecker interface {
	// IsAdmin reports whether the user identified by id has admin privileges.
	IsAdmin(ctx context.Context, id int) (bool, error)
}
// RequireSelfOrAdmin authorizes an action on targetID's resources.
//
// The action is allowed when the caller acts on their own account
// (currentID == targetID) or when svc reports the caller as an admin.
// A currentID of 0 is treated as "no user" and yields ErrUnauthenticated.
// A failed admin lookup is returned unchanged. Otherwise a non-admin gets
// ErrForbidden. svc is consulted only when the IDs differ.
func RequireSelfOrAdmin(ctx context.Context, svc authzAdminChecker, currentID, targetID int) error {
	switch {
	case currentID == 0:
		return ErrUnauthenticated
	case currentID == targetID:
		return nil
	}

	admin, err := svc.IsAdmin(ctx, currentID)
	switch {
	case err != nil:
		return err
	case admin:
		return nil
	default:
		return ErrForbidden
	}
}
package handlers
import (
"context"
"encoding/json"
"time"
"quizapp/internal/api"
"quizapp/internal/models"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
openapi_types "github.com/oapi-codegen/runtime/types"
)
// Pointer helpers for populating the optional (pointer-typed) fields of the
// generated API structs. Each call returns a freshly allocated value.

// stringPtr returns a pointer to a new copy of s.
func stringPtr(s string) *string {
	p := new(string)
	*p = s
	return p
}

// boolPtr returns a pointer to a new copy of b.
func boolPtr(b bool) *bool {
	p := new(bool)
	*p = b
	return p
}

// int64Ptr widens the int i to int64 and returns a pointer to it.
func int64Ptr(i int) *int64 {
	p := new(int64)
	*p = int64(i)
	return p
}

// float32Ptr returns a pointer to a new copy of f.
func float32Ptr(f float32) *float32 {
	p := new(float32)
	*p = f
	return p
}

// intPtr returns a pointer to a new copy of i.
func intPtr(i int) *int {
	p := new(int)
	*p = i
	return p
}
// formatTime renders t as an RFC 3339 timestamp after normalising it to UTC,
// so the output always carries the "Z" suffix.
func formatTime(t time.Time) string {
	return t.UTC().Format(time.RFC3339)
}

// formatTimePtr is formatTime returning a pointer, for optional API fields.
func formatTimePtr(t time.Time) *string {
	s := formatTime(t)
	return &s
}

// formatTimePointer formats *tp like formatTime, propagating a nil input as nil.
func formatTimePointer(tp *time.Time) *string {
	if tp == nil {
		return nil
	}
	return formatTimePtr(*tp)
}
// convertUserToAPI maps a models.User onto the generated api.User type.
//
// Nullable columns are copied only when Valid, so absent values stay nil.
// AiEnabled is always set (false when NULL). HasApiKey is always false here
// because checking it needs the user service; use
// convertUserToAPIWithService for the real value. Roles are included only
// when the user has at least one.
func convertUserToAPI(user *models.User) api.User {
	// optional yields a pointer to v only when the source column is non-NULL.
	optional := func(valid bool, v string) *string {
		if !valid {
			return nil
		}
		return stringPtr(v)
	}

	out := api.User{
		Id:                int64Ptr(user.ID),
		Username:          stringPtr(user.Username),
		Email:             optional(user.Email.Valid, user.Email.String),
		Timezone:          optional(user.Timezone.Valid, user.Timezone.String),
		PreferredLanguage: optional(user.PreferredLanguage.Valid, user.PreferredLanguage.String),
		CurrentLevel:      optional(user.CurrentLevel.Valid, user.CurrentLevel.String),
		AiProvider:        optional(user.AIProvider.Valid, user.AIProvider.String),
		AiModel:           optional(user.AIModel.Valid, user.AIModel.String),
		// ai_enabled is never null: a NULL column reads as false.
		AiEnabled: boolPtr(user.AIEnabled.Valid && user.AIEnabled.Bool),
		// Placeholder kept for backwards compatibility; see convertUserToAPIWithService.
		HasApiKey: boolPtr(false),
	}

	if !user.CreatedAt.IsZero() {
		out.CreatedAt = formatTimePtr(user.CreatedAt)
	}
	if user.LastActive.Valid {
		out.LastActive = formatTimePtr(user.LastActive.Time)
	}

	if n := len(user.Roles); n > 0 {
		roles := make([]api.Role, 0, n)
		for _, r := range user.Roles {
			roles = append(roles, api.Role{
				Id:          int64(r.ID),
				Name:        r.Name,
				Description: r.Description,
				CreatedAt:   formatTime(r.CreatedAt),
				UpdatedAt:   formatTime(r.UpdatedAt),
			})
		}
		out.Roles = &roles
	}
	return out
}
// convertUserToAPIWithService is convertUserToAPI plus a real HasApiKey value.
//
// HasApiKey is true only when the user has a non-empty AI provider, a
// userService is supplied, and that service returns a non-empty key for the
// provider without error. Lookup errors are deliberately treated as "no key".
// The key itself is never copied into the response.
func convertUserToAPIWithService(ctx context.Context, user *models.User, userService services.UserServiceInterface) api.User {
	result := convertUserToAPI(user)

	keyPresent := false
	provider := user.AIProvider
	if userService != nil && provider.Valid && provider.String != "" {
		// Per-provider API key lookup (supersedes the legacy user.AIAPIKey field).
		key, err := userService.GetUserAPIKey(ctx, user.ID, provider.String)
		keyPresent = err == nil && key != ""
	}

	result.HasApiKey = &keyPresent
	return result
}
// convertQuestionToAPI maps a models.Question onto the generated api.Question.
//
// ID, difficulty score and correct answer are always set. The other scalar
// fields are set only when non-empty / non-zero. The free-form Content map is
// translated key by key:
//   - string values under "question", "hint", "passage" and "sentence";
//   - string entries of "options", accepted both as []interface{} (what JSON
//     decoding produces) and as []string (content built directly in Go code).
//     Non-string entries are skipped, and an empty result leaves Options unset.
//
// Note: the optional string fields (Explanation, TopicCategory, ...) are
// returned as pointers into *question, so the result aliases the model.
func convertQuestionToAPI(question *models.Question) api.Question {
	apiQuestion := api.Question{
		Id:              int64Ptr(question.ID),
		DifficultyScore: float32Ptr(float32(question.DifficultyScore)),
		CorrectAnswer:   intPtr(question.CorrectAnswer),
		// UsageCount removed; use total_responses instead
	}
	if !question.CreatedAt.IsZero() {
		v := formatTime(question.CreatedAt)
		apiQuestion.CreatedAt = &v
	}
	if question.Type != "" {
		qType := api.QuestionType(question.Type)
		apiQuestion.Type = &qType
	}
	if question.Language != "" {
		lang := api.Language(question.Language)
		apiQuestion.Language = &lang
	}
	if question.Level != "" {
		level := api.Level(question.Level)
		apiQuestion.Level = &level
	}
	if question.Explanation != "" {
		apiQuestion.Explanation = &question.Explanation
	}
	if question.Status != "" {
		status := api.QuestionStatus(question.Status)
		apiQuestion.Status = &status
	}

	// Convert content map to api.QuestionContent
	if question.Content != nil {
		content := &api.QuestionContent{}
		if q, ok := question.Content["question"].(string); ok {
			content.Question = q
		}
		if hint, ok := question.Content["hint"].(string); ok {
			content.Hint = &hint
		}
		if passage, ok := question.Content["passage"].(string); ok {
			content.Passage = &passage
		}
		if sentence, ok := question.Content["sentence"].(string); ok {
			content.Sentence = &sentence
		}
		switch raw := question.Content["options"].(type) {
		case []interface{}:
			options := make([]string, 0, len(raw))
			for _, opt := range raw {
				if o, ok := opt.(string); ok {
					options = append(options, o)
				}
			}
			if len(options) > 0 {
				content.Options = options
			}
		case []string:
			// Without this case, options built in Go (not JSON-decoded) were
			// silently dropped. Copy so the API value does not alias the model.
			if len(raw) > 0 {
				content.Options = append([]string(nil), raw...)
			}
		}
		apiQuestion.Content = content
	}

	// Add variety elements to the API response
	if question.TopicCategory != "" {
		apiQuestion.TopicCategory = &question.TopicCategory
	}
	if question.GrammarFocus != "" {
		apiQuestion.GrammarFocus = &question.GrammarFocus
	}
	if question.VocabularyDomain != "" {
		apiQuestion.VocabularyDomain = &question.VocabularyDomain
	}
	if question.Scenario != "" {
		apiQuestion.Scenario = &question.Scenario
	}
	if question.StyleModifier != "" {
		apiQuestion.StyleModifier = &question.StyleModifier
	}
	if question.DifficultyModifier != "" {
		apiQuestion.DifficultyModifier = &question.DifficultyModifier
	}
	if question.TimeContext != "" {
		apiQuestion.TimeContext = &question.TimeContext
	}
	return apiQuestion
}
// convertQuestionWithStatsToAPIMap renders a question plus its response
// statistics as a generic JSON object.
//
// The generated api.Question carries the fields. It is round-tripped through
// encoding/json into a map so that extras missing from the generated type
// (currently "report_reasons") can be attached. Consequences of the round trip:
// numeric values in the map are float64, and a marshal failure is deliberately
// ignored, leaving at most report_reasons. A nil q yields an empty-question map.
func convertQuestionWithStatsToAPIMap(q *services.QuestionWithStats) map[string]interface{} {
	var apiQ api.Question
	if q != nil {
		if q.Question != nil {
			apiQ = convertQuestionToAPI(q.Question)
		}
		apiQ.CorrectCount = intPtr(q.CorrectCount)
		apiQ.IncorrectCount = intPtr(q.IncorrectCount)
		apiQ.TotalResponses = intPtr(q.TotalResponses)
		apiQ.UserCount = intPtr(q.UserCount)
		if q.Reporters != "" {
			apiQ.Reporters = &q.Reporters
		}
		if q.ConfidenceLevel != nil {
			apiQ.ConfidenceLevel = q.ConfidenceLevel
		}
	}

	out := make(map[string]interface{})
	encoded, err := json.Marshal(apiQ)
	if err == nil {
		_ = json.Unmarshal(encoded, &out) // best effort; see doc comment
	}

	if q != nil && q.ReportReasons != "" {
		out["report_reasons"] = q.ReportReasons
	}
	return out
}
// convertUserProgressToAPI maps a models.UserProgress onto api.UserProgress.
//
// AccuracyRate is divided by 100 on the way out. NOTE(review): presumably the
// model stores a percentage and the API expects a 0–1 fraction — confirm.
// Timestamps are rendered as RFC 3339 in the user's timezone through
// contextutils.FormatTimeInUserTimezone (using ctx, userID and userLookup).
// They fall back to UTC when that fails or yields an empty string.
// progress must be non-nil. Nil entries in PerformanceByTopic are skipped.
func convertUserProgressToAPI(ctx context.Context, progress *models.UserProgress, userID int, userLookup func(context.Context, int) (*models.User, error)) api.UserProgress {
	// formatForUser renders t in the user's timezone, or in UTC as a fallback.
	formatForUser := func(t time.Time) *string {
		s, _, err := contextutils.FormatTimeInUserTimezone(ctx, userID, t, time.RFC3339, userLookup)
		if err != nil || s == "" {
			s = t.In(time.UTC).Format(time.RFC3339)
		}
		return &s
	}

	apiProgress := api.UserProgress{
		TotalQuestions: intPtr(progress.TotalQuestions),
		CorrectAnswers: intPtr(progress.CorrectAnswers),
		AccuracyRate:   float32Ptr(float32(progress.AccuracyRate / 100.0)),
	}
	if progress.CurrentLevel != "" {
		level := api.Level(progress.CurrentLevel)
		apiProgress.CurrentLevel = &level
	}
	if progress.SuggestedLevel != "" {
		level := api.Level(progress.SuggestedLevel)
		apiProgress.SuggestedLevel = &level
	}
	if progress.WeakAreas != nil {
		apiProgress.WeakAreas = &progress.WeakAreas
	}

	// Convert performance metrics
	if progress.PerformanceByTopic != nil {
		perfMap := make(map[string]api.PerformanceMetrics, len(progress.PerformanceByTopic))
		for topic, metrics := range progress.PerformanceByTopic {
			if metrics == nil {
				continue
			}
			entry := api.PerformanceMetrics{
				TotalAttempts:         intPtr(metrics.TotalAttempts),
				CorrectAttempts:       intPtr(metrics.CorrectAttempts),
				AverageResponseTimeMs: float32Ptr(float32(metrics.AverageResponseTimeMs)),
			}
			if !metrics.LastUpdated.IsZero() {
				entry.LastUpdated = formatForUser(metrics.LastUpdated)
			}
			perfMap[topic] = entry
		}
		apiProgress.PerformanceByTopic = &perfMap
	}

	// Convert recent activity
	if progress.RecentActivity != nil {
		recent := make([]api.UserResponse, 0, len(progress.RecentActivity))
		for _, activity := range progress.RecentActivity {
			// Copy into a per-iteration local. &activity.IsCorrect would point at
			// the range variable (shared by all iterations before Go 1.22, so
			// every entry would report the last value) or, for pointer elements,
			// into the caller's model.
			isCorrect := activity.IsCorrect
			entry := api.UserResponse{
				QuestionId: int64Ptr(activity.QuestionID),
				IsCorrect:  &isCorrect,
			}
			if !activity.CreatedAt.IsZero() {
				entry.CreatedAt = formatForUser(activity.CreatedAt)
			}
			recent = append(recent, entry)
		}
		apiProgress.RecentActivity = &recent
	}
	return apiProgress
}
// convertDailyAssignmentToAPI maps one daily-question assignment (with its
// joined question and per-user counters) onto api.DailyQuestionWithDetails.
//
// CompletedAt, CreatedAt and SubmittedAt are rendered RFC 3339 in the user's
// timezone via userLookup, falling back to UTC on failure or an empty result.
// When DailyShownCount > 0 it replaces the embedded question's
// total_responses, so the UI's "Shown" figure reflects daily impressions only.
// Each per-user counter is attached only when it is non-negative.
func convertDailyAssignmentToAPI(ctx context.Context, assignment *models.DailyQuestionAssignmentWithQuestion, userID int, userLookup func(context.Context, int) (*models.User, error)) api.DailyQuestionWithDetails {
	// inUserTZ renders t in the user's timezone, or in UTC if that fails or is empty.
	inUserTZ := func(t time.Time) string {
		s, _, err := contextutils.FormatTimeInUserTimezone(ctx, userID, t, time.RFC3339, userLookup)
		if err != nil || s == "" {
			return t.In(time.UTC).Format(time.RFC3339)
		}
		return s
	}

	var completedAt *string
	if assignment.CompletedAt.Valid {
		v := inUserTZ(assignment.CompletedAt.Time)
		completedAt = &v
	}

	var question api.Question
	if assignment.Question != nil {
		question = convertQuestionToAPI(assignment.Question)
		if assignment.DailyShownCount > 0 {
			question.TotalResponses = &assignment.DailyShownCount
		}
	}

	createdAt := inUserTZ(assignment.CreatedAt)

	var submittedAt *string
	if assignment.SubmittedAt != nil {
		v := inUserTZ(*assignment.SubmittedAt)
		submittedAt = &v
	}

	result := api.DailyQuestionWithDetails{
		Id:         int64(assignment.ID),
		UserId:     int64(assignment.UserID),
		QuestionId: int64(assignment.QuestionID),
		// openapi_types.Date serialises as a date-only value (YYYY-MM-DD).
		AssignmentDate:  openapi_types.Date{Time: assignment.AssignmentDate},
		IsCompleted:     assignment.IsCompleted,
		CompletedAt:     completedAt,
		CreatedAt:       createdAt,
		UserAnswerIndex: assignment.UserAnswerIndex,
		SubmittedAt:     submittedAt,
		Question:        question,
	}

	if n := assignment.DailyShownCount; n >= 0 {
		v := int64(n)
		result.UserShownCount = &v
	}
	if n := assignment.UserTotalResponses; n >= 0 {
		v := int64(n)
		result.UserTotalResponses = &v
	}
	if n := assignment.UserCorrectCount; n >= 0 {
		v := int64(n)
		result.UserCorrectCount = &v
	}
	if n := assignment.UserIncorrectCount; n >= 0 {
		v := int64(n)
		result.UserIncorrectCount = &v
	}
	return result
}
// convertDailyAssignmentsToAPI converts each assignment with
// convertDailyAssignmentToAPI, preserving order. It always returns a non-nil
// slice, so an empty input encodes as [] rather than null.
func convertDailyAssignmentsToAPI(ctx context.Context, assignments []*models.DailyQuestionAssignmentWithQuestion, userID int, userLookup func(context.Context, int) (*models.User, error)) []api.DailyQuestionWithDetails {
	out := make([]api.DailyQuestionWithDetails, 0, len(assignments))
	for _, a := range assignments {
		out = append(out, convertDailyAssignmentToAPI(ctx, a, userID, userLookup))
	}
	return out
}
// convertDailyProgressToAPI copies a day's completion counters into
// api.DailyProgress. The date is wrapped as openapi_types.Date so it serialises
// date-only. progress must be non-nil.
func convertDailyProgressToAPI(progress *models.DailyProgress) api.DailyProgress {
	var out api.DailyProgress
	out.Date = openapi_types.Date{Time: progress.Date}
	out.Completed = progress.Completed
	out.Total = progress.Total
	return out
}
package handlers
import (
"context"
"net/http"
"strconv"
"strings"
"time"
"quizapp/internal/api"
"quizapp/internal/config"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)
// DailyQuestionHandler handles daily question-related HTTP requests
// (the /v1/daily/... endpoints implemented by its methods in this file).
// Construct it with NewDailyQuestionHandler.
type DailyQuestionHandler struct {
// userService looks up users by ID; the handlers use it to resolve the
// user's timezone.
userService services.UserServiceInterface
// dailyQuestionService performs the daily-question operations (listing,
// completion, answer submission, progress, history) the handlers delegate to.
dailyQuestionService services.DailyQuestionServiceInterface
// cfg holds the application configuration. NOTE(review): no handler method
// in this file reads it — confirm whether it is still needed.
cfg *config.Config
// logger records handler failures and diagnostic events.
logger *observability.Logger
}
// NewDailyQuestionHandler wires a DailyQuestionHandler from its dependencies.
// The arguments are stored as given; none are validated here.
func NewDailyQuestionHandler(
	userService services.UserServiceInterface,
	dailyQuestionService services.DailyQuestionServiceInterface,
	cfg *config.Config,
	logger *observability.Logger,
) *DailyQuestionHandler {
	h := new(DailyQuestionHandler)
	h.userService = userService
	h.dailyQuestionService = dailyQuestionService
	h.cfg = cfg
	h.logger = logger
	return h
}
// ParseDateInUserTimezone parses dateStr as a calendar date in the given user's
// timezone. It delegates to contextutils.ParseDateInUserTimezone, with the
// handler's user service as the user lookup. The string result is what callers
// log as the user's "timezone".
func (h *DailyQuestionHandler) ParseDateInUserTimezone(ctx context.Context, userID int, dateStr string) (time.Time, string, error) {
	lookup := h.userService.GetUserByID
	return contextutils.ParseDateInUserTimezone(ctx, userID, dateStr, lookup)
}
// GetDailyQuestions handles GET /v1/daily/questions/{date}.
//
// {date} is interpreted in the authenticated user's timezone. On success it
// responds 200 with {"questions": [...], "date": "<date as requested>"}.
// Failures go through HandleAppError: unauthorized, missing or malformed date,
// user lookup failure, or a service error.
func (h *DailyQuestionHandler) GetDailyQuestions(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_daily_questions")
	defer observability.FinishSpan(span, nil)

	uid, ok := GetUserIDFromSession(c)
	if !ok {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	requested := c.Param("date")
	if requested == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	day, tzName, err := h.ParseDateInUserTimezone(ctx, uid, requested)
	switch {
	case err != nil && strings.Contains(err.Error(), "invalid date format"):
		// The parser signals malformed input only through its message text.
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	case err != nil:
		HandleAppError(c, contextutils.WrapError(err, "failed to get user information"))
		return
	}

	span.SetAttributes(
		observability.AttributeUserID(uid),
		attribute.String("date", requested),
		attribute.String("timezone", tzName),
	)

	assignments, err := h.dailyQuestionService.GetDailyQuestions(ctx, uid, day)
	if err != nil {
		h.logger.Error(ctx, "Failed to get daily questions", err, map[string]interface{}{
			"user_id":  uid,
			"date":     requested,
			"timezone": tzName,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to get daily questions"))
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"questions": convertDailyAssignmentsToAPI(ctx, assignments, uid, h.userService.GetUserByID),
		"date":      requested,
	})
}
// MarkQuestionCompleted handles POST /v1/daily/questions/{date}/complete/{questionId}.
//
// It marks the user's assignment of questionId on {date} (interpreted in the
// user's timezone) as completed. Success returns 200 with a SuccessResponse.
// A missing assignment maps to ErrAssignmentNotFound. Other service errors are
// wrapped as "failed to mark question as completed".
func (h *DailyQuestionHandler) MarkQuestionCompleted(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "mark_daily_question_completed")
	defer observability.FinishSpan(span, nil)

	uid, ok := GetUserIDFromSession(c)
	if !ok {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	requested, rawQID := c.Param("date"), c.Param("questionId")
	if requested == "" || rawQID == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	day, tzName, err := h.ParseDateInUserTimezone(ctx, uid, requested)
	switch {
	case err != nil && strings.Contains(err.Error(), "invalid date format"):
		// The parser signals malformed input only through its message text.
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	case err != nil:
		HandleAppError(c, contextutils.WrapError(err, "failed to get user information"))
		return
	}

	qid, convErr := strconv.Atoi(rawQID)
	if convErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	span.SetAttributes(
		observability.AttributeUserID(uid),
		attribute.String("date", requested),
		attribute.Int("question_id", qid),
		attribute.String("timezone", tzName),
	)

	if svcErr := h.dailyQuestionService.MarkQuestionCompleted(ctx, uid, qid, day); svcErr != nil {
		h.logger.Error(ctx, "Failed to mark daily question as completed", svcErr, map[string]interface{}{
			"user_id":     uid,
			"question_id": qid,
			"date":        requested,
			"timezone":    tzName,
		})
		if contextutils.IsError(svcErr, contextutils.ErrAssignmentNotFound) {
			HandleAppError(c, contextutils.ErrAssignmentNotFound)
		} else {
			HandleAppError(c, contextutils.WrapError(svcErr, "failed to mark question as completed"))
		}
		return
	}

	c.JSON(http.StatusOK, api.SuccessResponse{
		Message: stringPtr("Question marked as completed"),
	})
}
// ResetQuestionCompleted handles DELETE /v1/daily/questions/{date}/complete/{questionId}.
//
// It clears the completion state of the user's assignment of questionId on
// {date} (interpreted in the user's timezone). Success returns 200 with a
// SuccessResponse. A missing assignment maps to ErrAssignmentNotFound. Other
// service errors are wrapped as "failed to reset question completion".
func (h *DailyQuestionHandler) ResetQuestionCompleted(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "reset_daily_question_completed")
	defer observability.FinishSpan(span, nil)

	uid, ok := GetUserIDFromSession(c)
	if !ok {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	requested, rawQID := c.Param("date"), c.Param("questionId")
	if requested == "" || rawQID == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	day, tzName, err := h.ParseDateInUserTimezone(ctx, uid, requested)
	switch {
	case err != nil && strings.Contains(err.Error(), "invalid date format"):
		// The parser signals malformed input only through its message text.
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	case err != nil:
		HandleAppError(c, contextutils.WrapError(err, "failed to get user information"))
		return
	}

	qid, convErr := strconv.Atoi(rawQID)
	if convErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	span.SetAttributes(
		observability.AttributeUserID(uid),
		attribute.String("date", requested),
		attribute.Int("question_id", qid),
		attribute.String("timezone", tzName),
	)

	if svcErr := h.dailyQuestionService.ResetQuestionCompleted(ctx, uid, qid, day); svcErr != nil {
		h.logger.Error(ctx, "Failed to reset daily question completion", svcErr, map[string]interface{}{
			"user_id":     uid,
			"question_id": qid,
			"date":        requested,
			"timezone":    tzName,
		})
		if contextutils.IsError(svcErr, contextutils.ErrAssignmentNotFound) {
			HandleAppError(c, contextutils.ErrAssignmentNotFound)
		} else {
			HandleAppError(c, contextutils.WrapError(svcErr, "failed to reset question completion"))
		}
		return
	}

	c.JSON(http.StatusOK, api.SuccessResponse{
		Message: stringPtr("Question completion reset"),
	})
}
// GetAvailableDates handles GET /v1/daily/dates.
//
// It responds 200 with {"dates": ["YYYY-MM-DD", ...]}: every date on which the
// authenticated user has daily-question assignments, excluding dates after the
// user's local "today". Clients expect local calendar days up to and including
// today only. Failures go through HandleAppError.
func (h *DailyQuestionHandler) GetAvailableDates(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_daily_available_dates")
	defer observability.FinishSpan(span, nil)
	userID, exists := GetUserIDFromSession(c)
	if !exists {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	// Add span attributes for observability
	span.SetAttributes(observability.AttributeUserID(userID))

	// Get available dates with assignments
	dates, err := h.dailyQuestionService.GetAvailableDates(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get available dates", err, map[string]interface{}{
			"user_id": userID,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to get available dates"))
		return
	}

	// Resolve the user's timezone. Lookup or parse failures deliberately fall
	// back to UTC instead of failing the request.
	user, _ := h.userService.GetUserByID(ctx, userID)
	tz := "UTC"
	if user != nil && user.Timezone.Valid && user.Timezone.String != "" {
		tz = user.Timezone.String
	}
	loc, err := time.LoadLocation(tz)
	if err != nil {
		loc = time.UTC
	}

	// Compare calendar days, not instants. The stored dates are date-only
	// values. Comparing them as instants against local midnight wrongly hides
	// today's date for users east of UTC: e.g. at UTC+10, 2024-05-10T00:00Z is
	// after local midnight 2024-05-09T14:00Z. Both sides are therefore reduced
	// to (year, month, day) and rebuilt in UTC for the comparison.
	todayY, todayM, todayD := time.Now().In(loc).Date()
	today := time.Date(todayY, todayM, todayD, 0, 0, 0, 0, time.UTC)

	dateStrings := make([]string, 0, len(dates))
	for _, d := range dates {
		// d.Date() uses d's own location, which is also what d.Format uses
		// below, so the filter agrees with the emitted string.
		y, m, day := d.Date()
		if time.Date(y, m, day, 0, 0, 0, 0, time.UTC).After(today) {
			continue
		}
		dateStrings = append(dateStrings, d.Format("2006-01-02"))
	}

	c.JSON(http.StatusOK, gin.H{
		"dates": dateStrings,
	})
}
// Note: Daily question assignment is now handled automatically by the worker
// when sending daily reminder emails. No manual assignment endpoint needed.
// GetDailyProgress handles GET /v1/daily/progress/{date}.
//
// {date} is interpreted in the authenticated user's timezone. On success it
// responds 200 with that day's api.DailyProgress. Failures go through
// HandleAppError.
func (h *DailyQuestionHandler) GetDailyProgress(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_daily_progress")
	defer observability.FinishSpan(span, nil)

	uid, ok := GetUserIDFromSession(c)
	if !ok {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	requested := c.Param("date")
	if requested == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	day, tzName, err := h.ParseDateInUserTimezone(ctx, uid, requested)
	switch {
	case err != nil && strings.Contains(err.Error(), "invalid date format"):
		// The parser signals malformed input only through its message text.
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	case err != nil:
		HandleAppError(c, contextutils.WrapError(err, "failed to get user information"))
		return
	}

	span.SetAttributes(
		observability.AttributeUserID(uid),
		attribute.String("date", requested),
		attribute.String("timezone", tzName),
	)

	progress, err := h.dailyQuestionService.GetDailyProgress(ctx, uid, day)
	if err != nil {
		h.logger.Error(ctx, "Failed to get daily progress", err, map[string]interface{}{
			"user_id":  uid,
			"date":     requested,
			"timezone": tzName,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to get daily progress"))
		return
	}

	c.JSON(http.StatusOK, convertDailyProgressToAPI(progress))
}
// SubmitDailyQuestionAnswer handles POST /v1/daily/questions/{date}/answer/{questionId}
//
// It authenticates the caller from the session, parses {date} in the user's
// timezone and {questionId} as an int, binds the JSON body into
// api.PostV1DailyQuestionsDateAnswerQuestionIdJSONBody, rejects a negative
// user_answer_index, and delegates grading/persistence to the daily question
// service. On success it responds 200 with the service's grading result plus
// "is_completed": true. If the service returns ErrQuestionAlreadyAnswered,
// ErrAssignmentNotFound or ErrInvalidAnswerIndex, the matching contextutils
// sentinel is sent. Any other service error is wrapped as "failed to submit answer".
func (h *DailyQuestionHandler) SubmitDailyQuestionAnswer(c *gin.Context) {
ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "submit_daily_question_answer")
defer observability.FinishSpan(span, nil)
// Diagnostic entry log, emitted at Info level on every request.
h.logger.Info(ctx, "SubmitDailyQuestionAnswer handler called", map[string]interface{}{
"method": c.Request.Method,
"path": c.Request.URL.Path,
"params": c.Params,
})
userID, exists := GetUserIDFromSession(c)
if !exists {
HandleAppError(c, contextutils.ErrUnauthorized)
return
}
// Parse parameters
dateStr := c.Param("date")
questionIDStr := c.Param("questionId")
if dateStr == "" || questionIDStr == "" {
HandleAppError(c, contextutils.ErrMissingRequired)
return
}
// Parse date in user's timezone
date, timezone, err := h.ParseDateInUserTimezone(ctx, userID, dateStr)
if err != nil {
// Check if it's an invalid date format error
// (the parser's error is classified by its message text only).
if strings.Contains(err.Error(), "invalid date format") {
HandleAppError(c, contextutils.ErrInvalidFormat)
return
}
HandleAppError(c, contextutils.WrapError(err, "failed to get user information"))
return
}
questionID, err := strconv.Atoi(questionIDStr)
if err != nil {
HandleAppError(c, contextutils.ErrInvalidFormat)
return
}
// Parse request body
var requestBody api.PostV1DailyQuestionsDateAnswerQuestionIdJSONBody
h.logger.Info(ctx, "Parsing request body", map[string]interface{}{
"user_id": userID,
"question_id": questionID,
"date": dateStr,
"timezone": timezone,
})
if err := c.ShouldBindJSON(&requestBody); err != nil {
// NOTE(review): the "error" field duplicates the err argument already passed to logger.Error.
h.logger.Error(ctx, "Failed to parse request body", err, map[string]interface{}{
"user_id": userID,
"question_id": questionID,
"date": dateStr,
"timezone": timezone,
"error": err.Error(),
})
HandleAppError(c, contextutils.NewAppErrorWithCause(
contextutils.ErrorCodeInvalidInput,
contextutils.SeverityWarn,
"Invalid request body",
"",
err,
))
return
}
h.logger.Info(ctx, "Request body parsed successfully",
map[string]interface{}{
"user_id": userID,
"question_id": questionID,
"date": dateStr,
"timezone": timezone,
"user_answer_index": requestBody.UserAnswerIndex,
})
// Validate user answer index
// Only the lower bound is checked here. An index past the last option is
// presumably rejected by the service (ErrInvalidAnswerIndex is mapped below).
if requestBody.UserAnswerIndex < 0 {
h.logger.Warn(ctx, "Invalid user answer index in SubmitDailyQuestionAnswer", map[string]interface{}{"user_answer_index": requestBody.UserAnswerIndex})
HandleAppError(c, contextutils.ErrInvalidAnswerIndex)
return
}
// Add span attributes for observability
span.SetAttributes(
observability.AttributeUserID(userID),
attribute.String("date", dateStr),
attribute.Int("question_id", questionID),
attribute.String("timezone", timezone),
attribute.Int("user_answer_index", requestBody.UserAnswerIndex),
)
// Submit the answer
response, err := h.dailyQuestionService.SubmitDailyQuestionAnswer(
ctx,
userID,
questionID,
date,
requestBody.UserAnswerIndex,
)
if err != nil {
h.logger.Error(ctx, "Failed to submit daily question answer", err, map[string]interface{}{
"user_id": userID,
"question_id": questionID,
"date": dateStr,
"timezone": timezone,
"user_answer_index": requestBody.UserAnswerIndex,
})
// Check for specific error types
// Known service failures are sent as their sentinel (not the original err).
if contextutils.IsError(err, contextutils.ErrQuestionAlreadyAnswered) {
HandleAppError(c, contextutils.ErrQuestionAlreadyAnswered)
return
}
if contextutils.IsError(err, contextutils.ErrAssignmentNotFound) {
HandleAppError(c, contextutils.ErrAssignmentNotFound)
return
}
if contextutils.IsError(err, contextutils.ErrInvalidAnswerIndex) {
HandleAppError(c, contextutils.ErrInvalidAnswerIndex)
return
}
HandleAppError(c, contextutils.WrapError(err, "failed to submit answer"))
return
}
// Add completion status to response
// is_completed is hard-coded: a successful submission is always reported as completed.
responseWithCompletion := gin.H{
"user_answer_index": response.UserAnswerIndex,
"user_answer": response.UserAnswer,
"is_correct": response.IsCorrect,
"correct_answer_index": response.CorrectAnswerIndex,
"explanation": response.Explanation,
"is_completed": true,
}
c.JSON(http.StatusOK, responseWithCompletion)
}
// GetQuestionHistory handles GET /v1/daily/questions/{questionId}/history.
//
// It responds 200 with {"history": [...]}: the authenticated user's daily
// assignments of one question over the last 14 days, as reported by the
// service. Each entry contains:
//   - assignment_date: "YYYY-MM-DD", the stored UTC calendar date;
//   - is_completed;
//   - is_correct: null when the service reports none;
//   - submitted_at: RFC 3339 in the user's timezone, or null when not submitted.
//
// Entries whose assignment date is after the user's local "today" are omitted.
// A non-nil error while formatting submitted_at fails the whole request.
func (h *DailyQuestionHandler) GetQuestionHistory(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_question_history")
	defer observability.FinishSpan(span, nil)
	userID, exists := GetUserIDFromSession(c)
	if !exists {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	// Parse question ID parameter
	questionIDStr := c.Param("questionId")
	if questionIDStr == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}
	questionID, err := strconv.Atoi(questionIDStr)
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}
	// Add span attributes for observability
	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.Int("question_id", questionID),
	)
	// Get question history for the last 14 days
	history, err := h.dailyQuestionService.GetQuestionHistory(ctx, userID, questionID, 14)
	if err != nil {
		h.logger.Error(ctx, "Failed to get question history", err, map[string]interface{}{
			"user_id":     userID,
			"question_id": questionID,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to get question history"))
		return
	}

	// Resolve the user's timezone once. Lookup or parse failures deliberately
	// fall back to UTC.
	user, _ := h.userService.GetUserByID(ctx, userID)
	tz := "UTC"
	if user != nil && user.Timezone.Valid && user.Timezone.String != "" {
		tz = user.Timezone.String
	}
	loc, locErr := time.LoadLocation(tz)
	if locErr != nil {
		loc = time.UTC
	}

	// The user's local calendar date, expressed as a UTC-midnight value so it
	// can be compared day-for-day with the stored assignment dates.
	todayY, todayM, todayD := time.Now().In(loc).Date()
	today := time.Date(todayY, todayM, todayD, 0, 0, 0, 0, time.UTC)

	resp := make([]map[string]interface{}, 0, len(history))
	for _, he := range history {
		// Filter on the same UTC calendar date that is emitted below. Converting
		// with In(loc) instead shifts a UTC-midnight date back one day for users
		// west of UTC, which let tomorrow's assignment slip past the filter.
		adUTC := he.AssignmentDate.UTC()
		ay, am, ad := adUTC.Date()
		if time.Date(ay, am, ad, 0, 0, 0, 0, time.UTC).After(today) {
			continue
		}
		// Return assignment_date as a date-only string (YYYY-MM-DD) from the
		// stored UTC date to avoid timezone ambiguity for clients.
		assignDateStr := adUTC.Format("2006-01-02")
		span.SetAttributes(attribute.String("assignment_date.formatted_with", "date_only"))
		entry := map[string]interface{}{
			"assignment_date": assignDateStr,
			"is_completed":    he.IsCompleted,
			"is_correct":      nil,
			"submitted_at":    nil,
		}
		if he.IsCorrect != nil {
			entry["is_correct"] = *he.IsCorrect
		}
		if he.SubmittedAt != nil {
			submittedStr, _, submittedErr := contextutils.FormatTimeInUserTimezone(ctx, userID, *he.SubmittedAt, time.RFC3339, h.userService.GetUserByID)
			if submittedErr != nil {
				h.logger.Error(ctx, "Failed to format submitted_at in user's timezone", submittedErr, map[string]interface{}{
					"user_id":         userID,
					"question_id":     questionID,
					"submitted_at_db": he.SubmittedAt,
				})
				span.RecordError(submittedErr, trace.WithStackTrace(true))
				span.SetStatus(codes.Error, "failed to format submitted_at")
				HandleAppError(c, contextutils.WrapError(submittedErr, "failed to format submitted_at"))
				return
			}
			if submittedStr == "" {
				// Empty result without an error: fall back to UTC, as the shared
				// converters do, rather than wrapping a nil error into a failure.
				submittedStr = he.SubmittedAt.UTC().Format(time.RFC3339)
			}
			span.SetAttributes(attribute.String("submitted_at.formatted_with", "user_timezone"))
			entry["submitted_at"] = submittedStr
		}
		resp = append(resp, entry)
	}
	c.JSON(http.StatusOK, gin.H{"history": resp})
}
package handlers
import (
"fmt"
"net/http"
contextutils "quizapp/internal/utils"
"github.com/gin-gonic/gin"
)
// StandardizeHTTPError writes a structured AppError JSON body with the given
// HTTP status code.
//
// The status is classified into an ErrorCode and SeverityLevel:
//   - 400 → invalid input (warn)
//   - 401 → unauthorized (warn)
//   - 403 → forbidden (warn)
//   - 404 → record not found (info)
//   - 409 → record exists (info)
//   - 503 → service unavailable (error)
//   - anything else → internal error (error)
//
// The response keeps the caller's original status code.
func StandardizeHTTPError(c *gin.Context, statusCode int, message, details string) {
	var (
		code     contextutils.ErrorCode     = contextutils.ErrorCodeInternalError
		severity contextutils.SeverityLevel = contextutils.SeverityError
	)
	switch statusCode {
	case http.StatusBadRequest:
		code, severity = contextutils.ErrorCodeInvalidInput, contextutils.SeverityWarn
	case http.StatusUnauthorized:
		code, severity = contextutils.ErrorCodeUnauthorized, contextutils.SeverityWarn
	case http.StatusForbidden:
		code, severity = contextutils.ErrorCodeForbidden, contextutils.SeverityWarn
	case http.StatusNotFound:
		code, severity = contextutils.ErrorCodeRecordNotFound, contextutils.SeverityInfo
	case http.StatusConflict:
		code, severity = contextutils.ErrorCodeRecordExists, contextutils.SeverityInfo
	case http.StatusServiceUnavailable:
		code, severity = contextutils.ErrorCodeServiceUnavailable, contextutils.SeverityError
	}

	appErr := contextutils.NewAppError(code, severity, message, details)
	c.JSON(statusCode, appErr.ToJSON())
}
// StandardizeAppError writes err as a structured JSON error response.
//
// The HTTP status is derived from err.Code via mapErrorCodeToHTTPStatus, and the
// body is err.ToJSON() extended with a "retryable" flag computed by
// contextutils.IsRetryable. err must be non-nil (err.Code is dereferenced).
func StandardizeAppError(c *gin.Context, err *contextutils.AppError) {
	body := err.ToJSON()
	body["retryable"] = contextutils.IsRetryable(err)
	c.JSON(mapErrorCodeToHTTPStatus(err.Code), body)
}
// HandleValidationError responds with an invalid-input error explaining why the
// value supplied for field was rejected.
//
// The error uses ErrorCodeInvalidInput at warning severity, which
// mapErrorCodeToHTTPStatus turns into 400 Bad Request. NOTE: value is echoed
// back to the client verbatim (formatted with %v).
func HandleValidationError(c *gin.Context, field string, value interface{}, reason string) {
	summary := fmt.Sprintf("Invalid %s", field)
	detail := fmt.Sprintf("Value '%v' is invalid: %s", value, reason)
	StandardizeAppError(c, contextutils.NewAppError(
		contextutils.ErrorCodeInvalidInput,
		contextutils.SeverityWarn,
		summary,
		detail,
	))
}
// HandleAppError writes err as a structured JSON error response.
//
// If err is, or wraps, a *contextutils.AppError, that AppError determines the
// status code and body (via StandardizeAppError). Any other error is reported as
// 500 Internal Server Error with err.Error() as the details. A nil error is
// reported as a generic 500 instead of panicking on err.Error().
func HandleAppError(c *gin.Context, err error) {
	if err == nil {
		StandardizeHTTPError(c, http.StatusInternalServerError, "Internal server error", "")
		return
	}

	// errors.As also finds an AppError further down a wrap chain
	// (e.g. fmt.Errorf("...: %w", appErr)), which a bare type assertion misses.
	var appErr *contextutils.AppError
	if errors.As(err, &appErr) && appErr != nil {
		StandardizeAppError(c, appErr)
		return
	}

	// Fallback for non-AppError types.
	// NOTE(review): err.Error() is returned to the client as details and may
	// expose internal information.
	StandardizeHTTPError(c, http.StatusInternalServerError, "Internal server error", err.Error())
}
// mapErrorCodeToHTTPStatus translates an AppError code into the HTTP status used
// for its response. Cases are ordered by status code; any code not listed
// (including codes added later) is treated as 500 Internal Server Error.
func mapErrorCodeToHTTPStatus(code contextutils.ErrorCode) int {
	switch code {
	case contextutils.ErrorCodeInvalidInput,
		contextutils.ErrorCodeMissingRequired,
		contextutils.ErrorCodeInvalidFormat,
		contextutils.ErrorCodeValidationFailed,
		contextutils.ErrorCodeOAuthStateMismatch:
		return http.StatusBadRequest

	case contextutils.ErrorCodeUnauthorized,
		contextutils.ErrorCodeSessionExpired,
		contextutils.ErrorCodeInvalidCredentials:
		return http.StatusUnauthorized

	case contextutils.ErrorCodeForbidden:
		return http.StatusForbidden

	case contextutils.ErrorCodeRecordNotFound,
		contextutils.ErrorCodeQuestionNotFound,
		contextutils.ErrorCodeAssignmentNotFound:
		return http.StatusNotFound

	case contextutils.ErrorCodeTimeout:
		return http.StatusRequestTimeout

	case contextutils.ErrorCodeRecordExists:
		return http.StatusConflict

	case contextutils.ErrorCodeRateLimit:
		return http.StatusTooManyRequests

	case contextutils.ErrorCodeServiceUnavailable,
		contextutils.ErrorCodeDatabaseConnection,
		contextutils.ErrorCodeAIProviderUnavailable:
		return http.StatusServiceUnavailable

	case contextutils.ErrorCodeInternalError,
		contextutils.ErrorCodeDatabaseQuery,
		contextutils.ErrorCodeDatabaseTransaction,
		contextutils.ErrorCodeForeignKeyViolation,
		contextutils.ErrorCodeTimestampMissingTimezone,
		contextutils.ErrorCodeAIRequestFailed,
		contextutils.ErrorCodeAIResponseInvalid,
		contextutils.ErrorCodeAIConfigInvalid,
		contextutils.ErrorCodeOAuthProviderError:
		return http.StatusInternalServerError

	default:
		return http.StatusInternalServerError
	}
}
package handlers
import (
"net/http"
"strconv"
"strings"
"github.com/gin-gonic/gin"
)
// ParsePagination reads the "page" and "page_size" query parameters and returns
// (page, size).
//
// A missing, non-numeric, or non-positive value falls back to defaultPage or
// defaultSize respectively, and the resulting size is then capped at maxSize.
// The defaults themselves are not validated, so callers should pass sane values.
func ParsePagination(c *gin.Context, defaultPage, defaultSize, maxSize int) (int, int) {
	page := defaultPage
	if n, err := strconv.Atoi(c.DefaultQuery("page", strconv.Itoa(defaultPage))); err == nil && n >= 1 {
		page = n
	}

	size := defaultSize
	if n, err := strconv.Atoi(c.DefaultQuery("page_size", strconv.Itoa(defaultSize))); err == nil && n >= 1 {
		size = n
	}
	if size > maxSize {
		size = maxSize
	}

	return page, size
}
// ParseFilters collects the query parameters named by keys into a map.
// Values are whitespace-trimmed; a key whose trimmed value is empty (or which is
// absent from the query) is left out of the result.
func ParseFilters(c *gin.Context, keys ...string) map[string]string {
	out := make(map[string]string, len(keys))
	for i := range keys {
		v := strings.TrimSpace(c.Query(keys[i]))
		if v == "" {
			continue
		}
		out[keys[i]] = v
	}
	return out
}
// WritePaginated writes a 200 OK JSON body of the form
// {itemsKey: items, "pagination": pagination, ...extra}.
//
// The caller picks itemsKey so existing response shapes are preserved. Entries
// from extra are applied last, so an extra key equal to itemsKey or
// "pagination" overrides that field.
func WritePaginated(c *gin.Context, itemsKey string, items, pagination any, extra gin.H) {
	body := make(gin.H, len(extra)+2)
	body[itemsKey] = items
	body["pagination"] = pagination
	for key, val := range extra {
		body[key] = val
	}
	c.JSON(http.StatusOK, body)
}
package handlers
import (
"context"
"encoding/json"
"fmt"
"io"
"math/rand"
"net/http"
"strconv"
"strings"
"time"
"quizapp/internal/api"
"quizapp/internal/models"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"quizapp/internal/config"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
)
// QuizHandler handles quiz-related HTTP requests including questions, answers,
// progress, question reports, AI chat streaming and worker status.
//
// Each handler method resolves the session user and delegates to the injected
// services below, then writes the JSON (or SSE) response.
type QuizHandler struct {
// userService looks up users (GetUserByID) and their per-provider AI API keys (GetUserAPIKey).
userService services.UserServiceInterface
// questionService fetches questions (by ID, with stats, next-for-user) and records reports.
questionService services.QuestionServiceInterface
// aiService produces streamed AI chat replies for ChatStream.
aiService services.AIServiceInterface
// learningService records answers and serves progress, preferences and confidence data.
learningService services.LearningServiceInterface
// workerService reports worker health and global/per-user pause state.
workerService services.WorkerServiceInterface
// hintService records generation hints; it may be nil (getNextQuestion checks before use).
hintService services.GenerationHintServiceInterface
// cfg is the application configuration; not referenced by the methods in this section.
cfg *config.Config
logger *observability.Logger
}
// NewQuizHandler creates a QuizHandler wired to the given services, application
// configuration and logger.
//
// hintService may be nil; getNextQuestion then skips recording generation hints.
// The configuration parameter is named cfg so it does not shadow the imported
// config package inside this function.
func NewQuizHandler(
	userService services.UserServiceInterface,
	questionService services.QuestionServiceInterface,
	aiService services.AIServiceInterface,
	learningService services.LearningServiceInterface,
	workerService services.WorkerServiceInterface,
	hintService services.GenerationHintServiceInterface,
	cfg *config.Config,
	logger *observability.Logger,
) *QuizHandler {
	return &QuizHandler{
		userService:     userService,
		questionService: questionService,
		aiService:       aiService,
		learningService: learningService,
		workerService:   workerService,
		hintService:     hintService,
		cfg:             cfg,
		logger:          logger,
	}
}
// getUserIDFromSession returns the session user's ID and whether a user is
// present in the session.
//
// Deprecated: use GetUserIDFromSession in session.go
func (h *QuizHandler) getUserIDFromSession(c *gin.Context) (int, bool) {
	userID, ok := GetUserIDFromSession(c)
	return userID, ok
}
// GetQuestion serves quiz questions for the session user.
//
// With an "id" path parameter it returns that specific question
// (getSpecificQuestion); otherwise it selects the user's next question
// (getNextQuestion). Requests without a session user get ErrUnauthorized.
func (h *QuizHandler) GetQuestion(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_question")
	defer observability.FinishSpan(span, nil)

	// The helpers below start their own spans from c.Request.Context(); attach
	// the traced context so those spans are recorded as children of this one.
	// (Assumes TraceHandlerFunction returns a context carrying the new span.)
	c.Request = c.Request.WithContext(ctx)

	userID, exists := GetUserIDFromSession(c)
	if !exists {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	// Add span attributes for observability
	span.SetAttributes(observability.AttributeUserID(userID))

	// Check if a specific question ID is requested
	if questionIDStr := c.Param("id"); questionIDStr != "" {
		span.SetAttributes(attribute.String("question.id", questionIDStr))
		h.getSpecificQuestion(c, userID, questionIDStr)
		return
	}

	h.getNextQuestion(c, userID)
}
// getSpecificQuestion responds with the question identified by questionIDStr,
// including aggregate response statistics and, when available, the user's
// confidence level for it. The explanation is stripped so it is not revealed
// before the user answers.
//
// A non-integer ID yields an invalid-input error; a lookup failure is wrapped
// and returned; an empty lookup result yields ErrQuestionNotFound. Failing to
// load the confidence level is logged and otherwise ignored.
func (h *QuizHandler) getSpecificQuestion(c *gin.Context, userID int, questionIDStr string) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_specific_question",
		observability.AttributeUserID(userID),
		attribute.String("question.id_str", questionIDStr),
	)
	defer observability.FinishSpan(span, nil)

	questionID, err := strconv.Atoi(questionIDStr)
	if err != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid question ID format",
			"Question ID must be a valid integer",
			err,
		))
		return
	}

	questionWithStats, err := h.questionService.GetQuestionWithStats(ctx, questionID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get question with stats", err, map[string]interface{}{
			"question_id": questionID,
			"user_id":     userID,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to get question with stats"))
		return
	}
	// Guard against a (nil, nil) result, which would otherwise panic below;
	// getNextQuestion performs the same check.
	if questionWithStats == nil || questionWithStats.Question == nil {
		span.SetAttributes(attribute.String("error.type", "question_nil"))
		HandleAppError(c, contextutils.ErrQuestionNotFound)
		return
	}

	// Convert and hide sensitive information
	apiQuestion := convertQuestionToAPI(questionWithStats.Question)
	apiQuestion.Explanation = nil // Hide explanation

	// Add response statistics to the API question
	apiQuestion.CorrectCount = &questionWithStats.CorrectCount
	apiQuestion.IncorrectCount = &questionWithStats.IncorrectCount
	apiQuestion.TotalResponses = &questionWithStats.TotalResponses

	// Get user-specific confidence level if available
	confidenceLevel, err := h.learningService.GetUserQuestionConfidenceLevel(ctx, userID, questionID)
	if err != nil {
		h.logger.Warn(ctx, "Failed to get user confidence level", map[string]interface{}{
			"error":       err.Error(),
			"question_id": questionID,
			"user_id":     userID,
		})
		// Don't fail the request, just continue without confidence level
	} else if confidenceLevel != nil {
		apiQuestion.ConfidenceLevel = confidenceLevel
	}

	c.JSON(http.StatusOK, apiQuestion)
}
// getNextQuestion responds with the next question for userID.
//
// Language and level default to the user's saved preferences (both must be
// set) and may be overridden with the "language" and "level" query parameters.
// The question type comes from resolveRequestedQuestionType. When a strictly
// requested type fails, one retry without a type filter is made. If no question
// is available the client receives 202 "generating"; for strict type requests a
// 10-minute generation hint is also recorded (best-effort) when a hint service
// is configured.
func (h *QuizHandler) getNextQuestion(c *gin.Context, userID int) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_next_question",
		observability.AttributeUserID(userID),
	)
	defer observability.FinishSpan(span, nil)

	user, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get user by ID", err, map[string]interface{}{
			"user_id": userID,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to get user by ID"))
		return
	}
	if user == nil {
		span.SetAttributes(attribute.String("error.type", "user_nil"))
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	// Check if user has required preferences set
	if !user.PreferredLanguage.Valid || user.PreferredLanguage.String == "" {
		span.SetAttributes(attribute.String("error.type", "missing_language_preference"))
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeMissingRequired,
			contextutils.SeverityWarn,
			"Language preference not set",
			"Please set your preferred language in settings",
			nil,
		))
		return
	}
	if !user.CurrentLevel.Valid || user.CurrentLevel.String == "" {
		span.SetAttributes(attribute.String("error.type", "missing_level_preference"))
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeMissingRequired,
			contextutils.SeverityWarn,
			"Level preference not set",
			"Please set your current level in settings",
			nil,
		))
		return
	}

	language := c.DefaultQuery("language", user.PreferredLanguage.String)
	level := c.DefaultQuery("level", user.CurrentLevel.String)
	qType, strictTypeRequested := h.resolveRequestedQuestionType(c)

	// Add span attributes for observability
	span.SetAttributes(
		attribute.String("language", language),
		attribute.String("level", level),
		attribute.String("question.type", string(qType)),
		attribute.Bool("strict.type.requested", strictTypeRequested),
	)

	questionWithStats, err := h.questionService.GetNextQuestion(ctx, userID, language, level, qType)
	if err != nil {
		h.logger.Error(ctx, "Failed to get next question", err, map[string]interface{}{
			"user_id":       userID,
			"language":      language,
			"level":         level,
			"question_type": string(qType),
		})
		if !strictTypeRequested {
			HandleAppError(c, contextutils.ErrNoQuestionsAvailable)
			return
		}

		// Fallback: the requested type may have nothing available yet, so retry
		// once without a type filter.
		h.logger.Info(ctx, "Attempting fallback without question type", map[string]interface{}{
			"user_id":  userID,
			"language": language,
			"level":    level,
		})
		questionWithStats, err = h.questionService.GetNextQuestion(ctx, userID, language, level, "")
		if err != nil {
			h.logger.Error(ctx, "Fallback also failed", err, map[string]interface{}{
				"user_id":  userID,
				"language": language,
				"level":    level,
			})
			HandleAppError(c, contextutils.ErrNoQuestionsAvailable)
			return
		}
	}

	// Check if we got a valid question
	if questionWithStats == nil || questionWithStats.Question == nil {
		h.logger.Error(ctx, "GetNextQuestion returned nil question", nil, map[string]interface{}{
			"user_id":       userID,
			"language":      language,
			"level":         level,
			"question_type": string(qType),
		})
		// If the user strictly requested a type, record a generation hint with short TTL
		if strictTypeRequested && h.hintService != nil && qType != "" {
			// Best-effort; do not fail the request if hint upsert fails
			_ = h.hintService.UpsertHint(ctx, userID, language, level, qType, 10*time.Minute)
		}
		c.JSON(http.StatusAccepted, api.GeneratingResponse{
			Status:  stringPtr("generating"),
			Message: stringPtr("No questions available. Prioritizing your requested question type. Please try again shortly."),
		})
		return
	}

	// Convert to API format and hide sensitive information
	apiQuestion := convertQuestionToAPI(questionWithStats.Question)
	apiQuestion.Explanation = nil // Hide explanation

	// Add response statistics to the API question
	apiQuestion.CorrectCount = &questionWithStats.CorrectCount
	apiQuestion.IncorrectCount = &questionWithStats.IncorrectCount
	apiQuestion.TotalResponses = &questionWithStats.TotalResponses

	// Add confidence level if available
	if questionWithStats.ConfidenceLevel != nil {
		apiQuestion.ConfidenceLevel = questionWithStats.ConfidenceLevel
	}

	c.JSON(http.StatusOK, apiQuestion)
}

// resolveRequestedQuestionType picks the question type for getNextQuestion from
// the query string and reports whether the client explicitly requested it.
//
//   - "type" (comma-separated): the first non-blank entry is used and the
//     request is strict. A list with only blank entries (e.g. "type=,") is
//     treated as if "type" were absent, instead of as a strict empty type.
//   - otherwise "exclude_type" (comma-separated): a random type outside that list.
//   - otherwise: a random type.
func (h *QuizHandler) resolveRequestedQuestionType(c *gin.Context) (models.QuestionType, bool) {
	for _, t := range strings.Split(c.Query("type"), ",") {
		if t = strings.TrimSpace(t); t != "" {
			return models.QuestionType(t), true
		}
	}

	if excludeTypes := c.Query("exclude_type"); excludeTypes != "" {
		var excludeSet []models.QuestionType
		for _, t := range strings.Split(excludeTypes, ",") {
			if t = strings.TrimSpace(t); t != "" {
				excludeSet = append(excludeSet, models.QuestionType(t))
			}
		}
		return h.selectRandomQuestionTypeExcluding(excludeSet...), false
	}

	// Default random selection
	return h.selectRandomQuestionType(), false
}
// SubmitAnswer grades the session user's answer to a question and records it.
//
// The body is an api.AnswerRequest. The answer is correct when UserAnswerIndex
// equals the question's CorrectAnswer. The attempt is stored through
// learningService.RecordAnswerWithPriority so priority scores stay current. The
// reply carries correctness, the chosen option's text (empty when it cannot be
// resolved from the question's "options"), the explanation and the correct
// answer index. A missing question, including an empty lookup result, yields
// ErrQuestionNotFound.
func (h *QuizHandler) SubmitAnswer(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "submit_answer")
	defer observability.FinishSpan(span, nil)

	userID, exists := GetUserIDFromSession(c)
	if !exists {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	var req api.AnswerRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		h.logger.Error(ctx, "Invalid answer request format", err, map[string]interface{}{
			"user_id": userID,
		})
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request format",
			"",
			err,
		))
		return
	}

	questionID := int(req.QuestionId)
	answerIndex := int(req.UserAnswerIndex)

	question, err := h.questionService.GetQuestionByID(ctx, questionID)
	// A (nil, nil) result is treated like a failed lookup rather than being
	// dereferenced below.
	if err != nil || question == nil {
		h.logger.Error(ctx, "Failed to get question by ID", err, map[string]interface{}{
			"question_id": req.QuestionId,
			"user_id":     userID,
		})
		HandleAppError(c, contextutils.ErrQuestionNotFound)
		return
	}

	// Check if answer is correct
	isCorrect := answerIndex == question.CorrectAnswer

	responseTimeMs := 0
	if req.ResponseTimeMs != nil {
		responseTimeMs = int(*req.ResponseTimeMs)
	}

	// Use priority-aware recording to ensure priority scores are updated.
	// Store the user's answer index for future reference.
	if err := h.learningService.RecordAnswerWithPriority(ctx, userID, questionID, answerIndex, isCorrect, responseTimeMs); err != nil {
		h.logger.Error(ctx, "Failed to record user response", err, map[string]interface{}{
			"user_id":     userID,
			"question_id": req.QuestionId,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to record response"))
		return
	}

	// Resolve the chosen option's text; only []interface{} options of strings
	// are understood, anything else leaves the text empty.
	userAnswerText := ""
	if options, ok := question.Content["options"].([]interface{}); ok && answerIndex >= 0 && answerIndex < len(options) {
		if optStr, ok := options[answerIndex].(string); ok {
			userAnswerText = optStr
		}
	}

	// Record span attributes before the response is written.
	span.SetAttributes(
		attribute.Int("user.id", userID),
		attribute.Int("question.id", questionID),
		attribute.Bool("answer.is_correct", isCorrect),
		attribute.Int("response.time_ms", responseTimeMs),
	)

	c.JSON(http.StatusOK, &api.AnswerResponse{
		IsCorrect:          &isCorrect,
		UserAnswer:         &userAnswerText,
		UserAnswerIndex:    &req.UserAnswerIndex,
		Explanation:        &question.Explanation,
		CorrectAnswerIndex: &question.CorrectAnswer,
	})
}
// GetProgress responds with the session user's learning progress.
//
// The core progress record is required: if learningService.GetUserProgress
// fails, the error is returned to the client. The supplementary sections are
// best-effort: worker status, learning preferences, priority insights,
// generation focus, high-priority topics, gap analysis and priority
// distribution. A failure in any of them is logged at warn level and that
// section is simply omitted from the response.
func (h *QuizHandler) GetProgress(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_progress")
	defer observability.FinishSpan(span, nil)

	userID, exists := GetUserIDFromSession(c)
	if !exists {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	span.SetAttributes(observability.AttributeUserID(userID))

	progress, err := h.learningService.GetUserProgress(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get user progress", err, map[string]interface{}{
			"user_id": userID,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to get progress"))
		return
	}

	// warnPartial logs a failed optional lookup without failing the request.
	warnPartial := func(msg string, lookupErr error) {
		if lookupErr == nil {
			return
		}
		h.logger.Warn(ctx, msg, map[string]interface{}{
			"user_id": userID,
			"error":   lookupErr.Error(),
		})
	}

	workerStatus, workerErr := h.getWorkerStatusForUser(ctx, userID)
	warnPartial("Failed to get worker status for user", workerErr)

	learningPrefs, prefsErr := h.learningService.GetUserLearningPreferences(ctx, userID)
	warnPartial("Failed to get learning preferences for user", prefsErr)

	priorityInsights, insightsErr := h.getPriorityInsightsForUser(ctx, userID)
	warnPartial("Failed to get priority insights for user", insightsErr)

	generationFocus, focusErr := h.getGenerationFocusForUser(ctx, userID)
	warnPartial("Failed to get generation focus for user", focusErr)

	highPriorityTopics, topicsErr := h.getHighPriorityTopicsForUser(ctx, userID)
	warnPartial("Failed to get high priority topics for user", topicsErr)

	gapAnalysis, gapErr := h.getGapAnalysisForUser(ctx, userID)
	warnPartial("Failed to get gap analysis for user", gapErr)

	priorityDistribution, distErr := h.getPriorityDistributionForUser(ctx, userID)
	warnPartial("Failed to get priority distribution for user", distErr)

	// Convert models.UserProgress to api.UserProgress, then attach whichever
	// optional sections were loaded.
	out := convertUserProgressToAPI(ctx, progress, userID, h.userService.GetUserByID)
	if workerStatus != nil {
		out.WorkerStatus = workerStatus
	}
	if learningPrefs != nil {
		out.LearningPreferences = convertLearningPreferencesToAPI(learningPrefs)
	}
	if priorityInsights != nil {
		out.PriorityInsights = priorityInsights
	}
	if generationFocus != nil {
		out.GenerationFocus = generationFocus
	}
	if highPriorityTopics != nil {
		out.HighPriorityTopics = &highPriorityTopics
	}
	if gapAnalysis != nil {
		out.GapAnalysis = &gapAnalysis
	}
	if priorityDistribution != nil {
		out.PriorityDistribution = &priorityDistribution
	}

	c.JSON(http.StatusOK, out)
}
// ReportQuestion lets the session user flag a question as problematic.
//
// The question ID comes from the "id" path parameter. An optional JSON body may
// carry "report_reason"; a missing or malformed body is treated as an empty
// reason. A record-not-found error from the question service is reported as
// ErrQuestionNotFound; any other failure is wrapped and returned.
func (h *QuizHandler) ReportQuestion(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "report_question")
	defer observability.FinishSpan(span, nil)

	userID, exists := GetUserIDFromSession(c)
	if !exists {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	rawID := c.Param("id")
	questionID, err := strconv.Atoi(rawID)
	if err != nil {
		HandleValidationError(c, "question_id", rawID, "must be a valid integer")
		return
	}

	// The body is optional, so a binding error just means "no reason given".
	var body struct {
		ReportReason *string `json:"report_reason"`
	}
	var reportReason string
	if bindErr := c.ShouldBindJSON(&body); bindErr == nil && body.ReportReason != nil {
		reportReason = *body.ReportReason
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		observability.AttributeQuestionID(questionID),
	)

	if reportErr := h.questionService.ReportQuestion(ctx, questionID, userID, reportReason); reportErr != nil {
		h.logger.Error(ctx, "Failed to report question", reportErr, map[string]interface{}{
			"question_id": questionID,
			"user_id":     userID,
		})
		if contextutils.IsError(reportErr, contextutils.ErrRecordNotFound) {
			HandleAppError(c, contextutils.ErrQuestionNotFound)
		} else {
			HandleAppError(c, contextutils.WrapError(reportErr, "failed to report question"))
		}
		return
	}

	c.JSON(http.StatusOK, api.SuccessResponse{Success: true, Message: stringPtr("Question reported successfully")})
}
// MarkQuestionAsKnown records that the session user already knows a question.
//
// The question ID comes from the "id" path parameter. An optional JSON body may
// carry "confidence_level"; a missing or malformed body is treated as no
// confidence level. An ErrQuestionNotFound from the learning service is passed
// through; any other failure is wrapped and returned.
func (h *QuizHandler) MarkQuestionAsKnown(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "mark_question_as_known")
	defer observability.FinishSpan(span, nil)

	userID, exists := GetUserIDFromSession(c)
	if !exists {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	rawID := c.Param("id")
	questionID, err := strconv.Atoi(rawID)
	if err != nil {
		HandleValidationError(c, "question_id", rawID, "must be a valid integer")
		return
	}

	// The body is optional, so a binding error just means "no confidence level".
	var body struct {
		ConfidenceLevel *int `json:"confidence_level"`
	}
	var confidence *int
	if bindErr := c.ShouldBindJSON(&body); bindErr == nil {
		confidence = body.ConfidenceLevel
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		observability.AttributeQuestionID(questionID),
	)

	if markErr := h.learningService.MarkQuestionAsKnown(ctx, userID, questionID, confidence); markErr != nil {
		h.logger.Error(ctx, "Failed to mark question as known for user", markErr, map[string]interface{}{
			"question_id": questionID,
			"user_id":     userID,
		})
		if contextutils.IsError(markErr, contextutils.ErrQuestionNotFound) {
			HandleAppError(c, contextutils.ErrQuestionNotFound)
		} else {
			HandleAppError(c, contextutils.WrapError(markErr, "failed to mark question as known"))
		}
		return
	}

	c.JSON(http.StatusOK, api.SuccessResponse{Success: true, Message: stringPtr("Question marked as known successfully")})
}
// ChatStream streams an AI chat reply about a quiz question as Server-Sent Events.
//
// The body is an api.QuizChatRequest: the question (language, level and type
// are required; content is optional), the user's message, optional answer
// context and optional prior conversation. The user's saved AI provider and
// model, plus the matching stored per-provider API key when one exists, are
// passed to the AI service.
//
// Each text chunk is sent as a `data: <json string>` SSE frame. Failures are
// sent as an `error` SSE event, carrying either the AI-service error text or
// "Request timeout" when the stream context ends. The stream context ends after
// config.QuizStreamTimeout or when the client disconnects.
func (h *QuizHandler) ChatStream(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "chat_stream")
	defer observability.FinishSpan(span, nil)

	userID, exists := GetUserIDFromSession(c)
	if !exists {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	var req api.QuizChatRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request format",
			"",
			err,
		))
		return
	}
	// Language, Level and Type are optional pointers in the request model;
	// dereferencing a missing one would panic, so reject incomplete payloads.
	if req.Question.Language == nil || req.Question.Level == nil || req.Question.Type == nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeMissingRequired,
			contextutils.SeverityWarn,
			"Question language, level and type are required",
			"",
			nil,
		))
		return
	}

	user, err := h.userService.GetUserByID(ctx, userID)
	if err != nil || user == nil {
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("ai.provider", user.AIProvider.String),
		attribute.String("ai.model", user.AIModel.String),
	)

	// Prepare the request for the AI service
	aiReq := &models.AIChatRequest{
		Language:     string(*req.Question.Language),
		Level:        string(*req.Question.Level),
		QuestionType: models.QuestionType(*req.Question.Type),
		UserMessage:  req.UserMessage,
	}
	if content := req.Question.Content; content != nil {
		aiReq.Question = content.Question
		aiReq.Options = content.Options
		if content.Passage != nil {
			aiReq.Passage = *content.Passage
		}
		// For vocabulary questions, use the sentence field as the passage
		if content.Sentence != nil && *req.Question.Type == "vocabulary" {
			aiReq.Passage = *content.Sentence
		}
	}
	if ac := req.AnswerContext; ac != nil {
		if ac.UserAnswer != nil {
			aiReq.UserAnswer = *ac.UserAnswer
		}
		if ac.IsCorrect != nil {
			aiReq.IsCorrect = ac.IsCorrect
		}
	}
	// Include conversation history if provided
	if req.ConversationHistory != nil {
		aiReq.ConversationHistory = make([]models.ChatMessage, len(*req.ConversationHistory))
		for i, msg := range *req.ConversationHistory {
			aiReq.ConversationHistory[i] = models.ChatMessage{
				Role:    msg.Role,
				Content: msg.Content,
			}
		}
	}

	// Empty provider/model mean "use the service defaults".
	userConfig := &services.UserAIConfig{
		Provider: "",
		Model:    "",
		APIKey:   "",
		Username: user.Username,
	}
	if user.AIProvider.Valid && user.AIProvider.String != "" {
		userConfig.Provider = user.AIProvider.String
	}
	if user.AIModel.Valid && user.AIModel.String != "" {
		userConfig.Model = user.AIModel.String
	}
	// Use the per-provider API key store (not the legacy user.AIAPIKey field).
	// Best-effort: a lookup error leaves APIKey empty.
	if userConfig.Provider != "" {
		if savedKey, keyErr := h.userService.GetUserAPIKey(ctx, userID, userConfig.Provider); keyErr == nil && savedKey != "" {
			userConfig.APIKey = savedKey
		}
	}

	// Set up Server-Sent Events headers
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Access-Control-Allow-Origin", "*")
	c.Header("Access-Control-Allow-Headers", "Cache-Control")

	// streamCtx is derived from the traced ctx, which in turn derives from the
	// request context. A client disconnect therefore cancels it without a
	// separate watcher goroutine, and the AI call stays inside this span.
	streamCtx, cancel := context.WithTimeout(ctx, config.QuizStreamTimeout)
	defer cancel()

	chunks := make(chan string, 10)

	// Producer: the AI service writes chunks; this goroutine is the only one
	// that closes the channel.
	go func() {
		defer func() {
			if r := recover(); r != nil {
				h.logger.Error(streamCtx, "Panic in AI streaming goroutine", nil, map[string]interface{}{
					"panic": r,
				})
			}
			close(chunks) // Close the channel when the goroutine completes
		}()

		if err := h.aiService.GenerateChatResponseStream(streamCtx, userConfig, aiReq, chunks); err != nil {
			h.logger.Error(streamCtx, "AI chat streaming failed for user", err, map[string]interface{}{
				"user_id": userID,
			})
			if streamCtx.Err() == nil {
				// Wait for the consumer to take the error instead of dropping it
				// when the buffer is full. Cannot block forever: cancel() runs when
				// ChatStream returns.
				select {
				case chunks <- fmt.Sprintf("ERROR: %v", err):
				case <-streamCtx.Done():
				}
			}
		}
	}()

	// Consumer: forward chunks to the client until the producer finishes, an
	// error chunk arrives, or the stream context ends.
	c.Stream(func(w io.Writer) bool {
		select {
		case chunk, ok := <-chunks:
			if !ok {
				// Channel closed, end streaming
				return false
			}
			// NOTE(review): errors are signalled in-band, so an AI text chunk
			// that itself starts with "ERROR: " is also treated as an error.
			if strings.HasPrefix(chunk, "ERROR: ") {
				c.SSEvent("error", strings.TrimPrefix(chunk, "ERROR: "))
				return false
			}
			// JSON-encode the chunk so newlines and special characters survive
			// the line-oriented SSE framing.
			jsonChunk, err := json.Marshal(chunk)
			if err != nil {
				h.logger.Error(streamCtx, "Failed to marshal chat stream chunk to JSON", err)
				return true // Continue streaming, skip this chunk
			}
			if _, err := fmt.Fprintf(w, "data: %s\n\n", jsonChunk); err != nil {
				h.logger.Error(streamCtx, "Failed to write chat stream data", err)
				return false
			}
			c.Writer.Flush()
			return true
		case <-streamCtx.Done():
			c.SSEvent("error", "Request timeout")
			return false
		}
	})
}
// Helper methods

// selectRandomQuestionType returns one of the four supported question types
// (vocabulary, fill-in-blank, question-answer, reading comprehension), chosen
// uniformly at random. It is the "exclude nothing" case of
// selectRandomQuestionTypeExcluding, so both share a single type list.
func (h *QuizHandler) selectRandomQuestionType() models.QuestionType {
	return h.selectRandomQuestionTypeExcluding()
}
// selectRandomQuestionTypeExcluding returns a uniformly random question type
// that does not appear in excludeTypes. Candidates keep the fixed order
// Vocabulary, FillInBlank, QuestionAnswer, ReadingComprehension. If every type
// is excluded it returns models.Vocabulary.
func (h *QuizHandler) selectRandomQuestionTypeExcluding(excludeTypes ...models.QuestionType) models.QuestionType {
	allTypes := []models.QuestionType{
		models.Vocabulary,
		models.FillInBlank,
		models.QuestionAnswer,
		models.ReadingComprehension,
	}

	candidates := make([]models.QuestionType, 0, len(allTypes))
	for _, qt := range allTypes {
		excluded := false
		for _, ex := range excludeTypes {
			if ex == qt {
				excluded = true
				break
			}
		}
		if !excluded {
			candidates = append(candidates, qt)
		}
	}

	if len(candidates) == 0 {
		return models.Vocabulary // Default fallback
	}
	return candidates[rand.Intn(len(candidates))]
}
// GetWorkerStatus reports question-generation worker status for the session user.
//
// The JSON response contains:
//   - the global and per-user pause flags;
//   - healthy/total worker counts taken from GetWorkerHealth;
//   - whether any worker instance reports is_running;
//   - the first non-empty last_run_error found across instances
//     (has_errors, error_message, last_error_details).
//
// Pause-check failures are logged and treated as "not paused". A failure to
// fetch worker health fails the request.
func (h *QuizHandler) GetWorkerStatus(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_worker_status")
	defer observability.FinishSpan(span, nil)

	userID, exists := GetUserIDFromSession(c)
	if !exists {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}
	span.SetAttributes(observability.AttributeUserID(userID))

	workerHealth, err := h.workerService.GetWorkerHealth(ctx)
	if err != nil {
		h.logger.Error(ctx, "Failed to get worker health", err)
		HandleAppError(c, contextutils.WrapError(err, "failed to get worker status"))
		return
	}

	userPaused, err := h.workerService.IsUserPaused(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Failed to check user pause status", err, nil)
		userPaused = false // Default to not paused if check fails
	}

	globalPaused, err := h.workerService.IsGlobalPaused(ctx)
	if err != nil {
		h.logger.Error(ctx, "Failed to check global pause status", err, nil)
		globalPaused = false // Default to not paused if check fails
	}

	response := gin.H{
		"has_errors":         false,
		"error_message":      "",
		"global_paused":      globalPaused,
		"user_paused":        userPaused,
		"healthy_workers":    workerHealth["healthy_count"],
		"total_workers":      workerHealth["total_count"],
		"last_error_details": "",
		"worker_running":     false,
	}

	if workerInstances, ok := workerHealth["worker_instances"].([]map[string]interface{}); ok {
		errorRecorded := false
		for _, instance := range workerInstances {
			// Report only the first error, but keep scanning: stopping the loop
			// would skip the is_running check for this and all later instances.
			if !errorRecorded {
				// Only string errors are reported; a missing key or other type is ignored.
				if errorStr, ok := instance["last_run_error"].(string); ok && errorStr != "" {
					response["has_errors"] = true
					response["error_message"] = "Worker encountered errors during question generation"
					response["last_error_details"] = errorStr
					errorRecorded = true
				}
			}
			if isRunning, ok := instance["is_running"].(bool); ok && isRunning {
				response["worker_running"] = true
			}
		}
	}

	c.JSON(http.StatusOK, response)
}
// Helper functions for enhanced progress information

// getWorkerStatusForUser summarizes the worker state for userID as an api.WorkerStatus.
//
// Status rules:
//   - idle when generation is paused globally or for the user;
//   - otherwise idle with ErrorMessage set when any worker instance reports a
//     non-empty "last_run_error" (the first such message is used);
//   - otherwise busy when at least one instance has "is_running" == true and an
//     RFC3339 "last_heartbeat" less than 30 seconds old;
//   - otherwise idle.
//
// LastHeartbeat is taken from the first worker instance only. An error is
// returned only when worker health cannot be fetched. Pause-lookup failures
// are logged and treated as "not paused".
func (h *QuizHandler) getWorkerStatusForUser(ctx context.Context, userID int) (*api.WorkerStatus, error) {
	workerHealth, err := h.workerService.GetWorkerHealth(ctx)
	if err != nil {
		return nil, err
	}

	// Best-effort pause checks; log instead of silently discarding the error.
	userPaused, err := h.workerService.IsUserPaused(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Failed to check user pause status", err, nil)
		userPaused = false
	}
	globalPaused, err := h.workerService.IsGlobalPaused(ctx)
	if err != nil {
		h.logger.Error(ctx, "Failed to check global pause status", err, nil)
		globalPaused = false
	}

	// A missing or differently-typed entry yields a nil slice, i.e. no instances.
	workerInstances, _ := workerHealth["worker_instances"].([]map[string]interface{})

	// Idle is the default; paused states are also reported as idle.
	var status api.WorkerStatusStatus = api.WorkerStatusStatusIdle
	var errorMessage *string
	if !globalPaused && !userPaused {
		for _, instance := range workerInstances {
			// Any reported error wins: status is idle and the first message is returned.
			if errorStr, ok := instance["last_run_error"].(string); ok && errorStr != "" {
				status = api.WorkerStatusStatusIdle
				errorMessage = &errorStr
				break
			}
			// A running worker only counts as busy with a very recent heartbeat,
			// so a running-but-idle worker is not reported as busy.
			if isRunning, ok := instance["is_running"].(bool); ok && isRunning {
				if heartbeatStr, ok := instance["last_heartbeat"].(string); ok {
					if heartbeat, err := time.Parse(time.RFC3339, heartbeatStr); err == nil && time.Since(heartbeat) < 30*time.Second {
						status = api.WorkerStatusStatusBusy
					}
				}
			}
		}
	}

	var lastHeartbeat *time.Time
	if len(workerInstances) > 0 {
		if heartbeatStr, ok := workerInstances[0]["last_heartbeat"].(string); ok {
			if heartbeat, err := time.Parse(time.RFC3339, heartbeatStr); err == nil {
				lastHeartbeat = &heartbeat
			}
		}
	}

	return &api.WorkerStatus{
		Status:        &status,
		LastHeartbeat: formatTimePointer(lastHeartbeat),
		ErrorMessage:  errorMessage,
	}, nil
}
// getPriorityInsightsForUser builds the user's priority-queue summary from the
// learning service's priority score distribution. The "high", "medium" and
// "low" entries are read as int counts. Missing entries, or entries of another
// type, count as zero. The total is the sum of the three buckets.
func (h *QuizHandler) getPriorityInsightsForUser(ctx context.Context, userID int) (*api.PriorityInsights, error) {
	dist, err := h.learningService.GetUserPriorityScoreDistribution(ctx, userID)
	if err != nil {
		return nil, err
	}

	bucket := func(key string) int {
		n, _ := dist[key].(int)
		return n
	}

	high := bucket("high")
	medium := bucket("medium")
	low := bucket("low")
	total := high + medium + low

	return &api.PriorityInsights{
		TotalQuestionsInQueue:   &total,
		HighPriorityQuestions:   &high,
		MediumPriorityQuestions: &medium,
		LowPriorityQuestions:    &low,
	}, nil
}
// getGenerationFocusForUser describes the user's question-generation focus.
//
// CurrentGenerationModel is the user's configured AI model. It is "default"
// when no model is set or when no user record is returned.
// LastGenerationTime and GenerationRate are PLACEHOLDERS (one hour ago, and a
// fixed 2.5), not measured values.
//
// Returns the error from the user lookup unchanged.
func (h *QuizHandler) getGenerationFocusForUser(ctx context.Context, userID int) (*api.GenerationFocus, error) {
	user, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		return nil, err
	}

	model := "default"
	// Guard against a (nil, nil) "not found" result so a missing user cannot
	// cause a nil-pointer panic here.
	if user != nil && user.AIModel.Valid && user.AIModel.String != "" {
		model = user.AIModel.String
	}

	// TODO: replace with the real last-generation time from generation logs.
	lastGenerationTime := time.Now().Add(-time.Hour)
	// TODO: replace with a measured rate (placeholder: questions per minute).
	generationRate := float32(2.5)

	return &api.GenerationFocus{
		CurrentGenerationModel: &model,
		LastGenerationTime:     formatTimePtr(lastGenerationTime),
		GenerationRate:         &generationRate,
	}, nil
}
// getHighPriorityTopicsForUser returns the user's high-priority topics from the
// learning service. On error any partial result is dropped and only the error
// is returned.
func (h *QuizHandler) getHighPriorityTopicsForUser(ctx context.Context, userID int) ([]string, error) {
	topics, lookupErr := h.learningService.GetHighPriorityTopics(ctx, userID)
	if lookupErr == nil {
		return topics, nil
	}
	return nil, lookupErr
}
// getGapAnalysisForUser returns the learning service's gap analysis for the
// user. On error any partial result is dropped and only the error is returned.
func (h *QuizHandler) getGapAnalysisForUser(ctx context.Context, userID int) (map[string]interface{}, error) {
	analysis, lookupErr := h.learningService.GetGapAnalysis(ctx, userID)
	if lookupErr == nil {
		return analysis, nil
	}
	return nil, lookupErr
}
// getPriorityDistributionForUser returns the learning service's priority
// distribution for the user. On error any partial result is dropped and only
// the error is returned.
func (h *QuizHandler) getPriorityDistributionForUser(ctx context.Context, userID int) (map[string]int, error) {
	dist, lookupErr := h.learningService.GetPriorityDistribution(ctx, userID)
	if lookupErr == nil {
		return dist, nil
	}
	return nil, lookupErr
}
// convertLearningPreferencesToAPI maps stored learning preferences onto the API
// type. Ratio, penalty and boost values are converted to float32. TtsVoice is
// set only when non-empty, and DailyGoal only when positive; otherwise both
// stay nil. prefs must be non-nil.
func convertLearningPreferencesToAPI(prefs *models.UserLearningPreferences) *api.UserLearningPreferences {
	converted := api.UserLearningPreferences{
		FocusOnWeakAreas:     prefs.FocusOnWeakAreas,
		FreshQuestionRatio:   float32(prefs.FreshQuestionRatio),
		KnownQuestionPenalty: float32(prefs.KnownQuestionPenalty),
		ReviewIntervalDays:   prefs.ReviewIntervalDays,
		WeakAreaBoost:        float32(prefs.WeakAreaBoost),
		DailyReminderEnabled: prefs.DailyReminderEnabled,
	}

	if voice := prefs.TTSVoice; voice != "" {
		converted.TtsVoice = &voice
	}
	if goal := prefs.DailyGoal; goal > 0 {
		converted.DailyGoal = &goal
	}

	return &converted
}
package handlers
import (
"fmt"
"net/http"
"sort"
"strings"
"time"
"quizapp/internal/observability"
"github.com/gin-gonic/gin"
)
// RouteInfo describes one registered route in the route listing. The JSON tags
// define the field names in the GetRouteListingJSON output.
type RouteInfo struct {
// Method is the HTTP method of the route (e.g. "GET", "POST").
Method string `json:"method"`
// Path is the route path as reported by gin's Engine.Routes(); it may contain
// parameter segments such as ":id".
Path string `json:"path"`
// HandlerName is the handler name gin reports for this route.
HandlerName string `json:"handler_name"`
}
// RouteListingHandler serves a listing of the routes registered on a Gin
// engine, as HTML (GetRouteListingPage) or JSON (GetRouteListingJSON).
// The listing is a snapshot taken by CollectRoutes.
type RouteListingHandler struct {
// serviceName is interpolated into the HTML page title and headings.
serviceName string
// routes is the snapshot from the most recent CollectRoutes call (empty until then).
routes []RouteInfo
}
// NewRouteListingHandler returns a route listing handler for serviceName with
// an empty (non-nil) route list. Call CollectRoutes to populate it.
func NewRouteListingHandler(serviceName string) *RouteListingHandler {
	h := new(RouteListingHandler)
	h.serviceName = serviceName
	// Non-nil so the JSON listing encodes as [] rather than null.
	h.routes = make([]RouteInfo, 0)
	return h
}
// CollectRoutes replaces the handler's route snapshot with every route
// currently registered on engine, skipping paths under "/debug/". Routes are
// ordered by path and then by method.
//
// Routes registered on engine after this call are not included.
func (h *RouteListingHandler) CollectRoutes(engine *gin.Engine) {
	registered := engine.Routes()

	// Build the new snapshot locally and keep it non-nil, so the JSON listing
	// encodes as [] rather than null.
	collected := make([]RouteInfo, 0, len(registered))
	for _, route := range registered {
		if strings.HasPrefix(route.Path, "/debug/") {
			continue
		}
		collected = append(collected, RouteInfo{
			Method:      route.Method,
			Path:        route.Path,
			HandlerName: route.Handler,
		})
	}

	// sort.Slice is not stable, so without a tie-break the relative order of
	// different methods on the same path (e.g. GET/POST .../userz) is unspecified.
	sort.Slice(collected, func(i, j int) bool {
		if collected[i].Path != collected[j].Path {
			return collected[i].Path < collected[j].Path
		}
		return collected[i].Method < collected[j].Method
	})

	h.routes = collected
}
// GetRouteListingPage responds with the HTML route listing. The response
// carries headers that tell browsers and proxies not to cache it.
func (h *RouteListingHandler) GetRouteListingPage(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_route_listing_page")
	defer observability.FinishSpan(span, nil)

	page := h.generateHTML()

	noCacheHeaders := [][2]string{
		{"Content-Type", "text/html; charset=utf-8"},
		{"Cache-Control", "no-cache, no-store, must-revalidate"},
		{"Pragma", "no-cache"},
		{"Expires", "0"},
	}
	for _, kv := range noCacheHeaders {
		c.Header(kv[0], kv[1])
	}

	c.String(http.StatusOK, page)
}
// GetRouteListingJSON responds with the collected routes encoded as JSON.
func (h *RouteListingHandler) GetRouteListingJSON(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_route_listing_json")
	defer observability.FinishSpan(span, nil)

	snapshot := h.routes
	c.JSON(http.StatusOK, snapshot)
}
// generateHTML renders the route listing as a standalone HTML page: a summary
// (service name, generation time, total/GET/POST counts) followed by a table
// with one row per collected route. GET paths get the "clickable-path" class.
// Clicking any path calls the page's navigateToRoute script, which navigates
// for GET and shows an alert for other methods.
//
// Values are interpolated without HTML escaping. They come from the service
// name and the registered route definitions, not from request data.
func (h *RouteListingHandler) generateHTML() string {
	// Compute the summary numbers once instead of inside the concatenation.
	total := fmt.Sprintf("%d", len(h.routes))
	getCount := fmt.Sprintf("%d", h.countMethods("GET"))
	postCount := fmt.Sprintf("%d", h.countMethods("POST"))

	var page strings.Builder
	page.WriteString(`<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>` + h.serviceName + ` - Available Routes</title>
<style>
body { font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; line-height: 1.6; padding: 20px; background-color: #f8f9fa; color: #212529; }
.container { max-width: 1200px; margin: auto; background: #fff; padding: 30px; border-radius: 8px; box-shadow: 0 4px 8px rgba(0,0,0,0.05); }
h1 { color: #0056b3; border-bottom: 2px solid #dee2e6; padding-bottom: 10px; margin-bottom: 30px; }
.service-info { background: #e7f3ff; padding: 15px; border-radius: 5px; margin-bottom: 30px; }
.route-table { width: 100%; border-collapse: collapse; margin-bottom: 30px; }
.route-table th, .route-table td { padding: 12px; text-align: left; border-bottom: 1px solid #dee2e6; }
.route-table th { background-color: #f8f9fa; font-weight: 600; color: #495057; }
.route-table tr:nth-child(even) { background-color: #f8f9fa; }
.route-table tr:hover { background-color: #e9ecef; }
.method { display: inline-block; padding: 4px 8px; border-radius: 4px; font-size: 12px; font-weight: bold; min-width: 60px; text-align: center; }
.method-get { background-color: #d4edda; color: #155724; }
.method-post { background-color: #cce5ff; color: #004085; }
.method-put { background-color: #fff3cd; color: #856404; }
.method-delete { background-color: #f8d7da; color: #721c24; }
.method-patch { background-color: #e2e3e5; color: #383d41; }
.path { font-family: "Monaco", "Menlo", "Ubuntu Mono", monospace; font-size: 14px; color: #6f42c1; }
.clickable-path { cursor: pointer; text-decoration: underline; }
.clickable-path:hover { background-color: #f8f9fa; }
.footer { margin-top: 30px; text-align: center; color: #6c757d; font-size: 14px; }
.stats { display: flex; gap: 20px; margin-bottom: 20px; }
.stat-box { background: #ffffff; border: 1px solid #dee2e6; padding: 15px; border-radius: 5px; text-align: center; flex: 1; }
.stat-number { font-size: 24px; font-weight: bold; color: #0056b3; }
.stat-label { color: #6c757d; font-size: 14px; }
</style>
</head>
<body>
<div class="container">
<h1>` + h.serviceName + ` Service - Available Routes</h1>
<div class="service-info">
<strong>Service:</strong> ` + h.serviceName + `<br>
<strong>Generated:</strong> ` + time.Now().Format("2006-01-02 15:04:05") + `<br>
<strong>Total Routes:</strong> ` + total + `
</div>
<div class="stats">
<div class="stat-box">
<div class="stat-number">` + total + `</div>
<div class="stat-label">Total Routes</div>
</div>
<div class="stat-box">
<div class="stat-number">` + getCount + `</div>
<div class="stat-label">GET Routes</div>
</div>
<div class="stat-box">
<div class="stat-number">` + postCount + `</div>
<div class="stat-label">POST Routes</div>
</div>
</div>
<table class="route-table">
<thead>
<tr>
<th>Method</th>
<th>Path</th>
<th>Handler</th>
</tr>
</thead>
<tbody>`)

	for _, route := range h.routes {
		pathClass := "path"
		// Only GET paths are styled as links; navigateToRoute alerts for the rest.
		if route.Method == "GET" {
			pathClass += " clickable-path"
		}
		// Write the row straight into the builder (no intermediate Sprintf string).
		fmt.Fprintf(&page, `
<tr>
<td><span class="method %s">%s</span></td>
<td><span class="%s" onclick="navigateToRoute('%s', '%s')">%s</span></td>
<td>%s</td>
</tr>`,
			"method-"+strings.ToLower(route.Method), route.Method,
			pathClass, route.Method, route.Path, route.Path,
			route.HandlerName,
		)
	}

	page.WriteString(`
</tbody>
</table>
<div class="footer">
<p>Click on any GET route path to navigate to it | <a href="/?json=true">View as JSON</a></p>
</div>
</div>
<script>
function navigateToRoute(method, path) {
if (method === 'GET') {
window.location.href = path;
} else {
alert('Only GET routes can be navigated to directly. Use API client for ' + method + ' requests.');
}
}
</script>
</body>
</html>`)

	return page.String()
}
// countMethods reports how many collected routes use exactly the given HTTP
// method (case-sensitive comparison).
func (h *RouteListingHandler) countMethods(method string) int {
	total := 0
	for i := range h.routes {
		if h.routes[i].Method != method {
			continue
		}
		total++
	}
	return total
}
package handlers
import (
"encoding/json"
"net/http"
"os"
"strings"
"time"
"github.com/gin-contrib/cors"
"github.com/gin-contrib/secure"
"github.com/gin-contrib/sessions"
"github.com/gin-contrib/sessions/cookie"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"quizapp/internal/config"
"quizapp/internal/middleware"
"quizapp/internal/observability"
"quizapp/internal/services"
"quizapp/internal/version"
)
// IMPORTANT: When adding new API endpoints, make sure to:
// 1. Add them to swagger.yaml with proper documentation
// 2. Run `task generate-api-types` to regenerate types
// 3. Update any relevant tests
// 4. Consider if the endpoint should be public or admin-only

// NewRouter creates the backend Gin engine. It wires up:
//   - global middleware: recovery, request logging, tracing, response
//     validation, CORS, sessions and security headers;
//   - the /v1 API routes;
//   - frontend static files with an SPA fallback;
//   - a route listing at "/".
//
// Middleware added with router.Use applies only to routes registered after
// it, so statement order matters here. /health and the Swagger files are
// deliberately registered before the CORS, session and security middleware.
//
// Note: gin.SetMode changes process-wide Gin state.
func NewRouter(
	cfg *config.Config,
	userService services.UserServiceInterface,
	questionService services.QuestionServiceInterface,
	learningService services.LearningServiceInterface,
	aiService services.AIServiceInterface,
	workerService services.WorkerServiceInterface,
	dailyQuestionService services.DailyQuestionServiceInterface,
	oauthService *services.OAuthService,
	generationHintService services.GenerationHintServiceInterface,
	logger *observability.Logger,
) *gin.Engine {
	// Choose the Gin mode before creating the engine so that debug-mode
	// warnings and route-registration logs are not emitted in release mode.
	gin.SetMode(gin.ReleaseMode)
	if cfg.Server.Debug {
		gin.SetMode(gin.DebugMode)
	}

	router := gin.New()
	router.Use(gin.Recovery())

	// HTTP request logging through the observability logger (stdout and OTLP).
	router.Use(func(c *gin.Context) {
		start := time.Now()

		c.Next()

		latency := time.Since(start)
		statusCode := c.Writer.Status()
		clientIP := c.ClientIP()
		method := c.Request.Method
		path := c.Request.URL.Path

		fields := map[string]interface{}{
			"http.method":      method,
			"http.path":        path,
			"http.status_code": statusCode,
			"http.latency_ms":  latency.Milliseconds(),
			"http.client_ip":   clientIP,
			"http.user_agent":  c.Request.UserAgent(),
		}
		if len(c.Errors) > 0 {
			fields["http.error"] = c.Errors.String()
		}

		if statusCode >= 400 {
			// Only the response size is recorded; the body itself is not
			// captured because it has already been written to the client.
			if c.Writer.Size() > 0 {
				fields["http.response_size"] = c.Writer.Size()
			}
			if statusCode >= 500 {
				fields["http.error_type"] = "server_error"
				// A server request's Body is never nil (net/http substitutes an
				// empty body), so check ContentLength instead: 0 means no body,
				// -1 means unknown length (e.g. chunked).
				if c.Request.ContentLength != 0 {
					fields["http.request_has_body"] = true
				}
			} else {
				fields["http.error_type"] = "client_error"
			}
		}

		switch {
		case statusCode >= 500:
			logger.Error(c.Request.Context(), "HTTP request failed", nil, fields)
		case statusCode >= 400:
			logger.Warn(c.Request.Context(), "HTTP request warning", fields)
		default:
			logger.Info(c.Request.Context(), "HTTP request", fields)
		}
	})

	// Health check. It is registered before the tracing, validation, CORS,
	// session and security middleware, so only recovery and request logging
	// apply to it.
	router.GET("/health", func(c *gin.Context) {
		c.JSON(http.StatusOK, gin.H{"status": "ok", "service": "backend"})
	})

	// OpenTelemetry middleware for HTTP tracing and context propagation with automatic error attributes
	router.Use(observability.GinMiddlewareWithErrorHandling("quiz-backend"))
	// Response validation middleware for API endpoints
	router.Use(middleware.ResponseValidationMiddleware(logger))

	// Swagger documentation. Tracing and response validation apply to it;
	// CORS, sessions and security headers do not.
	router.StaticFile("/swagger.yaml", "./swagger.yaml")
	router.StaticFile("/swaggerz", "./swaggerz.html")

	// Disable automatic redirection for trailing slashes, which is better for APIs
	router.RedirectTrailingSlash = false

	// CORS
	corsConfig := cors.DefaultConfig()
	corsConfig.AllowOrigins = cfg.Server.CORSOrigins
	corsConfig.AllowCredentials = true
	corsConfig.AllowHeaders = []string{"Origin", "Content-Length", "Content-Type", "Authorization", "X-Requested-With"}
	corsConfig.AllowMethods = []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}
	router.Use(cors.New(corsConfig))

	// Cookie-backed sessions
	store := cookie.NewStore([]byte(cfg.Server.SessionSecret))
	sessionOpts := sessions.Options{
		Path:     config.SessionPath,
		MaxAge:   int(config.SessionMaxAge.Seconds()),
		HttpOnly: config.SessionHTTPOnly,
		Secure:   config.SessionSecure, // Set to true in production with HTTPS
	}
	if cfg.Server.Debug {
		sessionOpts.SameSite = http.SameSiteDefaultMode
	} else {
		sessionOpts.SameSite = http.SameSiteLaxMode
		sessionOpts.Secure = true
	}
	store.Options(sessionOpts)
	router.Use(sessions.Sessions(config.SessionName, store))

	// Security headers
	secureConfig := secure.DefaultConfig()
	secureConfig.SSLRedirect = false
	secureConfig.ContentSecurityPolicy = config.DefaultCSP
	router.Use(secure.New(secureConfig))

	// Handlers
	authHandler := NewAuthHandler(userService, oauthService, cfg, logger)
	emailService := services.CreateEmailService(cfg, logger)
	settingsHandler := NewSettingsHandler(userService, aiService, learningService, emailService, cfg, logger)
	quizHandler := NewQuizHandler(userService, questionService, aiService, learningService, workerService, generationHintService, cfg, logger)
	dailyQuestionHandler := NewDailyQuestionHandler(userService, dailyQuestionService, cfg, logger)
	adminHandler := NewAdminHandlerWithLogger(userService, questionService, aiService, cfg, learningService, workerService, logger)
	userAdminHandler := NewUserAdminHandler(userService, cfg, logger)

	// One shared, trace-instrumented client for worker version lookups. The
	// timeout keeps /v1/version from hanging when the worker is slow or unreachable.
	workerVersionClient := &http.Client{
		Transport: otelhttp.NewTransport(http.DefaultTransport),
		Timeout:   5 * time.Second,
	}

	// V1 routes (matching swagger spec)
	v1 := router.Group("/v1")
	{
		// Version aggregation endpoint (no auth)
		v1.GET("/version", func(c *gin.Context) {
			backendVersion := gin.H{
				"service":   "backend",
				"version":   version.Version,
				"commit":    version.Commit,
				"buildTime": version.BuildTime,
			}

			workerInternalURL := os.Getenv("WORKER_INTERNAL_URL")
			if workerInternalURL == "" {
				workerInternalURL = cfg.Server.WorkerInternalURL // fallback
			}

			var workerResp *http.Response
			req, err := http.NewRequestWithContext(c.Request.Context(), http.MethodGet, workerInternalURL+"/v1/version", nil)
			if err == nil {
				workerResp, err = workerVersionClient.Do(req)
			}
			if err == nil {
				// Close the body for every response, including non-200 ones,
				// so the connection is released.
				defer func() { _ = workerResp.Body.Close() }()
			}

			var workerVersion interface{}
			if err == nil && workerResp.StatusCode == http.StatusOK {
				if err := json.NewDecoder(workerResp.Body).Decode(&workerVersion); err != nil {
					workerVersion = gin.H{"error": "Failed to decode worker version"}
				}
			} else {
				workerVersion = gin.H{"error": "Worker unavailable"}
			}

			c.JSON(http.StatusOK, gin.H{
				"backend": backendVersion,
				"worker":  workerVersion,
			})
		})

		auth := v1.Group("/auth")
		{
			auth.POST("/login", middleware.RequestValidationMiddleware(logger), authHandler.Login)
			auth.POST("/logout", authHandler.Logout)
			auth.GET("/status", authHandler.Status)
			auth.GET("/check", middleware.RequireAuth(), authHandler.Check)
			auth.POST("/signup", middleware.RequestValidationMiddleware(logger), authHandler.Signup)
			auth.GET("/signup/status", authHandler.SignupStatus)
			auth.GET("/google/login", authHandler.GoogleLogin)
			auth.GET("/google/callback", authHandler.GoogleCallback)
		}

		quiz := v1.Group("/quiz")
		quiz.Use(middleware.RequireAuth())
		quiz.Use(middleware.RequestValidationMiddleware(logger))
		{
			quiz.GET("/question", quizHandler.GetQuestion)
			quiz.GET("/question/:id", quizHandler.GetQuestion)
			quiz.POST("/question/:id/report", quizHandler.ReportQuestion)
			quiz.POST("/question/:id/mark-known", quizHandler.MarkQuestionAsKnown)
			quiz.POST("/answer", quizHandler.SubmitAnswer)
			quiz.GET("/progress", quizHandler.GetProgress)
			quiz.GET("/worker-status", quizHandler.GetWorkerStatus)
			quiz.POST("/chat/stream", quizHandler.ChatStream)
		}

		daily := v1.Group("/daily")
		daily.Use(middleware.RequireAuth())
		daily.Use(middleware.RequestValidationMiddleware(logger))
		{
			daily.GET("/questions/:date", dailyQuestionHandler.GetDailyQuestions)
			daily.POST("/questions/:date/complete/:questionId", dailyQuestionHandler.MarkQuestionCompleted)
			daily.DELETE("/questions/:date/complete/:questionId", dailyQuestionHandler.ResetQuestionCompleted)
			daily.POST("/questions/:date/answer/:questionId", dailyQuestionHandler.SubmitDailyQuestionAnswer)
			daily.GET("/history/:questionId", dailyQuestionHandler.GetQuestionHistory)
			daily.GET("/dates", dailyQuestionHandler.GetAvailableDates)
			daily.GET("/progress/:date", dailyQuestionHandler.GetDailyProgress)
			// Note: Assignment is handled automatically by the worker
		}

		settings := v1.Group("/settings")
		{
			settings.GET("/ai-providers", middleware.RequireAuth(), settingsHandler.GetProviders)
			settings.GET("/levels", settingsHandler.GetLevels)
			settings.GET("/languages", settingsHandler.GetLanguages)
			settings.POST("/test-ai", middleware.RequireAuth(), middleware.RequestValidationMiddleware(logger), settingsHandler.TestAIConnection)
			settings.POST("/test-email", middleware.RequireAuth(), middleware.RequestValidationMiddleware(logger), settingsHandler.SendTestEmail)
			settings.PUT("", middleware.RequireAuth(), middleware.RequestValidationMiddleware(logger), settingsHandler.UpdateUserSettings)
			settings.GET("/api-key/:provider", middleware.RequireAuth(), settingsHandler.CheckAPIKeyAvailability)
		}

		preferences := v1.Group("/preferences")
		preferences.Use(middleware.RequireAuth())
		preferences.Use(middleware.RequestValidationMiddleware(logger))
		{
			preferences.GET("/learning", settingsHandler.GetLearningPreferences)
			preferences.PUT("/learning", settingsHandler.UpdateLearningPreferences)
		}

		// User management endpoints (non-admin)
		userz := v1.Group("/userz")
		{
			userz.PUT("/profile", middleware.RequireAuth(), middleware.RequestValidationMiddleware(logger), userAdminHandler.UpdateCurrentUserProfile)
		}

		// Admin endpoints
		admin := v1.Group("/admin")
		admin.Use(middleware.RequireAdmin(userService))
		admin.Use(middleware.RequestValidationMiddleware(logger))
		{
			// Backend admin endpoints
			backend := admin.Group("/backend")
			{
				// Backend admin page
				backend.GET("", adminHandler.GetBackendAdminPage)

				// User management (admin only)
				backend.GET("/userz", userAdminHandler.GetAllUsers)
				backend.GET("/userz/paginated", userAdminHandler.GetUsersPaginated)
				backend.POST("/userz", userAdminHandler.CreateUser)
				backend.PUT("/userz/:id", userAdminHandler.UpdateUser)
				backend.DELETE("/userz/:id", userAdminHandler.DeleteUser)
				backend.POST("/userz/:id/reset-password", userAdminHandler.ResetUserPassword)

				// Role management endpoints
				backend.GET("/roles", adminHandler.GetRoles)
				backend.GET("/userz/:id/roles", adminHandler.GetUserRoles)
				backend.POST("/userz/:id/roles", adminHandler.AssignRole)
				backend.DELETE("/userz/:id/roles/:roleId", adminHandler.RemoveRole)

				// Admin dashboard data
				backend.GET("/dashboard", adminHandler.GetBackendAdminData)
				backend.GET("/ai-concurrency", adminHandler.GetAIConcurrencyStats)

				// Question management
				backend.GET("/questions/:id", adminHandler.GetQuestion)
				backend.GET("/questions/:id/users", adminHandler.GetUsersForQuestion)
				backend.PUT("/questions/:id", adminHandler.UpdateQuestion)
				backend.DELETE("/questions/:id", adminHandler.DeleteQuestion)
				backend.POST("/questions/:id/assign-users", adminHandler.AssignUsersToQuestion)
				backend.POST("/questions/:id/unassign-users", adminHandler.UnassignUsersFromQuestion)
				backend.GET("/questions/paginated", adminHandler.GetQuestionsPaginated)
				backend.GET("/questions", adminHandler.GetAllQuestions)
				backend.GET("/reported-questions", adminHandler.GetReportedQuestionsPaginated)
				backend.POST("/questions/:id/fix", adminHandler.MarkQuestionAsFixed)
				backend.POST("/questions/:id/ai-fix", adminHandler.FixQuestionWithAI)

				// Data management
				backend.POST("/clear-user-data", adminHandler.ClearUserData)
				backend.POST("/clear-database", adminHandler.ClearDatabase)
				backend.POST("/userz/:id/clear", adminHandler.ClearUserDataForUser)
			}
		}
	}

	// Config dump endpoint
	router.GET("/configz", adminHandler.GetConfigz)

	// Frontend static files. Directories are served with Static (which
	// registers a /prefix/*filepath route). StaticFile only serves a single
	// file, so it could not serve anything under /fonts/.
	router.Static("/assets", "./frontend/dist/assets")
	router.StaticFile("/favicon.svg", "./frontend/dist/favicon.svg")
	router.Static("/fonts", "./frontend/dist/fonts")

	// Catch-all route for SPA - serve index.html for any route that doesn't match API routes
	router.NoRoute(func(c *gin.Context) {
		// Don't serve index.html for API routes
		if strings.HasPrefix(c.Request.URL.Path, "/v1/") ||
			strings.HasPrefix(c.Request.URL.Path, "/configz") ||
			strings.HasPrefix(c.Request.URL.Path, "/swagger") ||
			strings.HasPrefix(c.Request.URL.Path, "/backend/") {
			c.JSON(http.StatusNotFound, gin.H{"error": "Not found"})
			return
		}
		// Serve the frontend's index.html for all other routes
		c.File("./frontend/dist/index.html")
	})

	// Automatic route listing at "/" (add ?json=true for JSON). "/" is
	// registered before the snapshot is taken so it appears in the listing.
	routeListing := NewRouteListingHandler("Backend")
	router.GET("/", func(c *gin.Context) {
		if c.Query("json") == "true" {
			routeListing.GetRouteListingJSON(c)
		} else {
			routeListing.GetRouteListingPage(c)
		}
	})
	routeListing.CollectRoutes(router)

	return router
}
package handlers
import (
"quizapp/internal/middleware"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
// GetUserIDFromSession looks up the authenticated user's ID in the session.
// The second result is false when no ID is stored or when the stored value is
// not an int; the ID is 0 in that case.
func GetUserIDFromSession(c *gin.Context) (int, bool) {
	if id, isInt := sessions.Default(c).Get(middleware.UserIDKey).(int); isInt {
		return id, true
	}
	return 0, false
}
package handlers
import (
"fmt"
"net/http"
"quizapp/internal/api"
"quizapp/internal/config"
"quizapp/internal/middleware"
"quizapp/internal/models"
"quizapp/internal/observability"
"quizapp/internal/services"
"quizapp/internal/services/mailer"
contextutils "quizapp/internal/utils"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
)
// SettingsHandler handles user-settings HTTP requests: settings updates, AI
// connection tests, provider/level/language lookups and saved-API-key checks.
type SettingsHandler struct {
// userService reads and updates user settings and saved per-provider API keys.
userService services.UserServiceInterface
// aiService is used to test connections to AI providers.
aiService services.AIServiceInterface
// learningService is not used by the handlers visible in this section;
// presumably it backs the learning-preference endpoints (confirm).
learningService services.LearningServiceInterface
// emailService is the mailer; presumably used by SendTestEmail (not visible here — confirm).
emailService mailer.Mailer
// cfg supplies the AI providers, levels, level descriptions and languages.
cfg *config.Config
// logger is the observability logger for this handler.
logger *observability.Logger
}
// NewSettingsHandler returns a SettingsHandler that uses the given services,
// configuration and logger. The arguments are stored as-is, without validation.
func NewSettingsHandler(userService services.UserServiceInterface, aiService services.AIServiceInterface, learningService services.LearningServiceInterface, emailService mailer.Mailer, cfg *config.Config, logger *observability.Logger) *SettingsHandler {
	h := &SettingsHandler{cfg: cfg, logger: logger}
	h.userService = userService
	h.aiService = aiService
	h.learningService = learningService
	h.emailService = emailService
	return h
}
// UpdateUserSettings handles updating the authenticated user's settings.
//
// The JSON body (api.UserSettings) must set at least one of Language, Level,
// AiProvider, AiModel, ApiKey or AiEnabled. Responses:
//   - ErrUnauthorized when the session has no user ID;
//   - an invalid-input error when the body cannot be bound, or when none of
//     the fields above is set;
//   - ErrInvalidFormat when Level or Language is present but is not one of the
//     configured values (an empty string is rejected too);
//   - ErrRecordNotFound when the user does not exist;
//   - 200 with {"success": true} on success.
func (h *SettingsHandler) UpdateUserSettings(c *gin.Context) {
	// Keep the traced ctx so the settings update is recorded under this span.
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "update_user_settings")
	defer observability.FinishSpan(span, nil)

	userID, ok := GetUserIDFromSession(c)
	if !ok {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	var settings api.UserSettings
	if err := c.ShouldBindJSON(&settings); err != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request body",
			"",
			err,
		))
		return
	}

	// Require at least one meaningful field. Only these explicit pointers are
	// checked, because generated union/raw fields may be non-nil even for an
	// empty JSON body.
	hasAnyField := settings.Language != nil ||
		settings.Level != nil ||
		settings.AiProvider != nil ||
		settings.AiModel != nil ||
		settings.ApiKey != nil ||
		settings.AiEnabled != nil
	if !hasAnyField {
		HandleAppError(c, contextutils.ErrInvalidInput)
		return
	}

	// Convert api.UserSettings to models.UserSettings, recording on the span
	// which fields were sent. The API key value itself is never recorded.
	modelSettings := models.UserSettings{}
	if settings.Language != nil {
		modelSettings.Language = string(*settings.Language)
		span.SetAttributes(attribute.String("settings.language", modelSettings.Language))
	}
	if settings.Level != nil {
		modelSettings.Level = string(*settings.Level)
		span.SetAttributes(attribute.String("settings.level", modelSettings.Level))
	}
	if settings.AiProvider != nil {
		modelSettings.AIProvider = *settings.AiProvider
		span.SetAttributes(attribute.String("settings.ai_provider", modelSettings.AIProvider))
	}
	if settings.AiModel != nil {
		modelSettings.AIModel = *settings.AiModel
		span.SetAttributes(attribute.String("settings.ai_model", modelSettings.AIModel))
	}
	if settings.ApiKey != nil {
		modelSettings.AIAPIKey = *settings.ApiKey
		span.SetAttributes(attribute.Bool("settings.api_key_provided", true))
	}
	if settings.AiEnabled != nil {
		modelSettings.AIEnabled = *settings.AiEnabled
		span.SetAttributes(attribute.Bool("settings.ai_enabled", modelSettings.AIEnabled))
	}

	// Validate level if provided (including empty string)
	if settings.Level != nil {
		isValidLevel := false
		for _, level := range h.cfg.GetAllLevels() {
			if modelSettings.Level == level {
				isValidLevel = true
				break
			}
		}
		if !isValidLevel {
			HandleAppError(c, contextutils.ErrInvalidFormat)
			return
		}
	}

	// Validate language if provided (including empty string)
	if settings.Language != nil {
		isValidLanguage := false
		for _, language := range h.cfg.GetLanguages() {
			if modelSettings.Language == language {
				isValidLanguage = true
				break
			}
		}
		if !isValidLanguage {
			HandleAppError(c, contextutils.ErrInvalidFormat)
			return
		}
	}

	if err := h.userService.UpdateUserSettings(ctx, userID, &modelSettings); err != nil {
		if contextutils.IsError(err, contextutils.ErrRecordNotFound) {
			HandleAppError(c, contextutils.ErrRecordNotFound)
			return
		}
		HandleAppError(c, contextutils.WrapError(err, "failed to update settings"))
		return
	}

	c.JSON(http.StatusOK, api.SuccessResponse{Success: true})
}
// TestAIConnection checks whether the AI provider/model in the request body
// (api.TestAIRequest) can be reached, on behalf of the authenticated user.
//
// When the request has no API key (absent or empty), the user's saved key for
// the provider is used instead. A failed connection is still answered with 200
// and {"success": false, "message": "Model '<model>': <error>"}. Only session,
// body-parsing and saved-key lookup failures produce an error response.
func (h *SettingsHandler) TestAIConnection(c *gin.Context) {
	// Keep the traced ctx so the key lookup and provider call are children of this span.
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "test_ai_connection")
	defer observability.FinishSpan(span, nil)

	userID, ok := GetUserIDFromSession(c)
	if !ok {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	var req api.TestAIRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request format",
			"",
			err,
		))
		return
	}

	provider := req.Provider
	model := req.Model
	apiKey := ""
	if req.ApiKey != nil {
		apiKey = *req.ApiKey
	}

	// No key in the request: fall back to the key saved in the user_api_keys table.
	if apiKey == "" {
		savedKey, err := h.userService.GetUserAPIKey(ctx, userID, provider)
		if err != nil {
			HandleAppError(c, contextutils.WrapError(err, "failed to get saved API key"))
			return
		}
		apiKey = savedKey
	}

	if err := h.aiService.TestConnection(ctx, provider, model, apiKey); err != nil {
		c.JSON(http.StatusOK, gin.H{
			"success": false,
			"message": fmt.Sprintf("Model '%s': %s", model, err.Error()),
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{"success": true, "message": "Connection successful"})
}
// GetProviders responds with one JSON object bundling the configured AI
// providers ("providers") with all configured levels ("levels") and
// languages ("languages").
func (h *SettingsHandler) GetProviders(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_providers")
	defer observability.FinishSpan(span, nil)

	cfg := h.cfg
	c.JSON(http.StatusOK, gin.H{
		"providers": cfg.Providers,
		"levels":    cfg.GetAllLevels(),
		"languages": cfg.GetLanguages(),
	})
}
// GetLevels responds with the configured levels and their descriptions.
//
// With a non-empty "language" query parameter only that language's levels and
// descriptions are returned; otherwise the values across all languages are.
func (h *SettingsHandler) GetLevels(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_levels")
	defer observability.FinishSpan(span, nil)

	lang := c.Query("language")
	if lang == "" {
		c.JSON(http.StatusOK, gin.H{
			"levels":             h.cfg.GetAllLevels(),
			"level_descriptions": h.cfg.GetAllLevelDescriptions(),
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"levels":             h.cfg.GetLevelsForLanguage(lang),
		"level_descriptions": h.cfg.GetLevelDescriptionsForLanguage(lang),
	})
}
// GetLanguages responds with the configured languages (h.cfg.GetLanguages())
// as the JSON response body.
func (h *SettingsHandler) GetLanguages(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_languages")
	defer observability.FinishSpan(span, nil)

	languages := h.cfg.GetLanguages()
	c.JSON(http.StatusOK, languages)
}
// CheckAPIKeyAvailability reports whether the session user has an API key saved
// for the provider given in the ":provider" route parameter, responding with
// {"has_api_key": ...}. Lookup failures are logged and then reported through
// HandleAppError.
func (h *SettingsHandler) CheckAPIKeyAvailability(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "check_api_key_availability")
	defer observability.FinishSpan(span, nil)

	userID, ok := sessions.Default(c).Get(middleware.UserIDKey).(int)
	if !ok {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	provider := c.Param("provider")
	if len(provider) == 0 {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	hasAPIKey, err := h.userService.HasUserAPIKey(ctx, userID, provider)
	if err == nil {
		c.JSON(http.StatusOK, gin.H{"has_api_key": hasAPIKey})
		return
	}

	h.logger.Error(ctx, "Failed to check API key availability", err, map[string]interface{}{
		"user_id":  userID,
		"provider": provider,
	})
	HandleAppError(c, contextutils.WrapError(err, "failed to check API key availability"))
}
// GetLearningPreferences responds with the session user's learning preferences,
// converted from the backend model to the API schema by
// convertLearningPreferencesToAPI.
func (h *SettingsHandler) GetLearningPreferences(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_learning_preferences")
	defer observability.FinishSpan(span, nil)

	userID, ok := sessions.Default(c).Get(middleware.UserIDKey).(int)
	if !ok {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	prefs, err := h.learningService.GetUserLearningPreferences(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get learning preferences", err, map[string]interface{}{
			"user_id": userID,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to get learning preferences"))
		return
	}

	c.JSON(http.StatusOK, convertLearningPreferencesToAPI(prefs))
}
// UpdateLearningPreferences updates the session user's learning preferences from
// the JSON request body and responds with the stored result in API form.
//
// The body's UserID field is always overwritten with the session user's ID,
// and the submitted values are recorded as attributes on the handler span.
func (h *SettingsHandler) UpdateLearningPreferences(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "update_learning_preferences")
	defer observability.FinishSpan(span, nil)

	userID, ok := sessions.Default(c).Get(middleware.UserIDKey).(int)
	if !ok {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	var prefs models.UserLearningPreferences
	if bindErr := c.ShouldBindJSON(&prefs); bindErr != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request body",
			"",
			bindErr,
		))
		return
	}

	// Bind the update to the session user regardless of what the body claims.
	prefs.UserID = userID

	span.SetAttributes(
		attribute.Bool("learning.focus_on_weak_areas", prefs.FocusOnWeakAreas),
		attribute.Bool("learning.include_review_questions", prefs.IncludeReviewQuestions),
		attribute.Float64("learning.fresh_question_ratio", prefs.FreshQuestionRatio),
		attribute.Float64("learning.known_question_penalty", prefs.KnownQuestionPenalty),
		attribute.Int("learning.review_interval_days", prefs.ReviewIntervalDays),
		attribute.Float64("learning.weak_area_boost", prefs.WeakAreaBoost),
	)

	updated, err := h.learningService.UpdateUserLearningPreferences(ctx, userID, &prefs)
	if err != nil {
		h.logger.Error(ctx, "Failed to update learning preferences", err, map[string]interface{}{
			"user_id": userID,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to update learning preferences"))
		return
	}

	c.JSON(http.StatusOK, convertLearningPreferencesToAPI(updated))
}
// SendTestEmail sends a test email (template "test_email") to the session
// user's email address and responds with a success message.
//
// Errors reported through HandleAppError:
//   - ErrUnauthorized when there is no user ID in the session;
//   - ErrRecordNotFound when the user record cannot be found;
//   - ErrMissingRequired when the user has no email address;
//   - ErrServiceUnavailable when the email service is disabled;
//   - wrapped errors for user lookup or send failures.
func (h *SettingsHandler) SendTestEmail(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "send_test_email")
	defer observability.FinishSpan(span, nil)

	session := sessions.Default(c)
	userID, ok := session.Get(middleware.UserIDKey).(int)
	if !ok {
		HandleAppError(c, contextutils.ErrUnauthorized)
		return
	}

	// Get the current user
	user, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get user for test email", err, map[string]interface{}{
			"user_id": userID,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to get user information"))
		return
	}
	// Other handlers in this package treat a (nil, nil) result from GetUserByID
	// as "not found"; guard here too instead of dereferencing a nil user below.
	if user == nil {
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	// Check if user has an email address
	if !user.Email.Valid || user.Email.String == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	// Check if email service is enabled
	if !h.emailService.IsEnabled() {
		HandleAppError(c, contextutils.ErrServiceUnavailable)
		return
	}

	// Send test email
	err = h.emailService.SendEmail(ctx, user.Email.String, "Test Email from Quiz App", "test_email", map[string]interface{}{
		"Username": user.Username,
		"TestTime": "now",
		"Message":  "This is a test email to verify your email settings are working correctly.",
	})
	if err != nil {
		h.logger.Error(ctx, "Failed to send test email", err, map[string]interface{}{
			"user_id": userID,
			"email":   user.Email.String,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to send test email"))
		return
	}

	h.logger.Info(ctx, "Test email sent successfully", map[string]interface{}{
		"user_id": userID,
		"email":   user.Email.String,
	})
	c.JSON(http.StatusOK, api.SuccessResponse{Success: true, Message: stringPtr("Test email sent successfully")})
}
//go:build integration
// +build integration
package handlers
import (
"context"
"encoding/json"
"strings"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
)
// MockAIService is a test double satisfying AIServiceInterface for integration
// tests. TestConnection and "fix"-style CallWithPrompt requests get canned
// answers; the remaining methods forward to realService when it is non-nil and
// otherwise fall back to simple defaults.
type MockAIService struct {
	// realService handles every call the mock does not fake. It may be nil,
	// in which case each method uses its built-in fallback.
	realService *services.AIService
}
// NewMockAIService builds a MockAIService whose non-mocked calls are delegated
// to a real AI service constructed from cfg and logger.
func NewMockAIService(cfg *config.Config, logger *observability.Logger) *MockAIService {
	delegate := services.NewAIService(cfg, logger)
	return &MockAIService{realService: delegate}
}
// TestConnection fakes an AI connectivity check without any network access.
//
// It fails with "missing provider or model" when either is empty, and with
// "invalid API key" when apiKey is empty or contains the substring "test" (so
// tests can simulate a rejected key). Every other combination succeeds.
func (m *MockAIService) TestConnection(ctx context.Context, provider, model, apiKey string) error {
	if provider == "" || model == "" {
		return contextutils.ErrorWithContextf("missing provider or model")
	}
	if apiKey == "" || strings.Contains(apiKey, "test") {
		return contextutils.ErrorWithContextf("invalid API key")
	}
	return nil
}
// CallWithPrompt fakes the AI "fix question" flow.
//
// A prompt containing any of the substrings "fix", "Fix", "problematic" or
// "report" is answered with a canned JSON fix suggestion. Any other prompt is
// forwarded to the real service, or answered with a fixed JSON string when no
// real service is configured.
func (m *MockAIService) CallWithPrompt(ctx context.Context, userConfig *services.UserAIConfig, prompt, grammar string) (string, error) {
	isFixRequest := false
	for _, keyword := range [...]string{"fix", "Fix", "problematic", "report"} {
		if strings.Contains(prompt, keyword) {
			isFixRequest = true
			break
		}
	}

	if isFixRequest {
		payload, err := json.Marshal(map[string]interface{}{
			"content": map[string]interface{}{
				"question":       "What is the capital of France?",
				"options":        []string{"Paris", "London", "Berlin", "Madrid"},
				"correct_answer": 0,
				"explanation":    "Paris is the capital and largest city of France.",
			},
			"correct_answer": 0,
			"explanation":    "Paris is the capital and largest city of France.",
			"change_reason":  "Fixed grammar and improved clarity of the question.",
		})
		if err != nil {
			return "", err
		}
		return string(payload), nil
	}

	if m.realService == nil {
		return `{"response": "Mock AI response"}`, nil
	}
	return m.realService.CallWithPrompt(ctx, userConfig, prompt, grammar)
}
// The remaining AIServiceInterface methods delegate to the real service when
// one is configured and otherwise return a simple default.

// GenerateQuestion forwards to the real service; without one it returns an error.
func (m *MockAIService) GenerateQuestion(ctx context.Context, userConfig *services.UserAIConfig, req *models.AIQuestionGenRequest) (*models.Question, error) {
	if m.realService == nil {
		return nil, contextutils.ErrorWithContextf("GenerateQuestion not implemented in mock")
	}
	return m.realService.GenerateQuestion(ctx, userConfig, req)
}
// GenerateQuestions forwards to the real service; without one it returns an error.
func (m *MockAIService) GenerateQuestions(ctx context.Context, userConfig *services.UserAIConfig, req *models.AIQuestionGenRequest) ([]*models.Question, error) {
	if m.realService == nil {
		return nil, contextutils.ErrorWithContextf("GenerateQuestions not implemented in mock")
	}
	return m.realService.GenerateQuestions(ctx, userConfig, req)
}
// GenerateQuestionsStream forwards to the real service; without one it returns
// an error and sends nothing on progress.
func (m *MockAIService) GenerateQuestionsStream(ctx context.Context, userConfig *services.UserAIConfig, req *models.AIQuestionGenRequest, progress chan<- *models.Question, variety *services.VarietyElements) error {
	if m.realService == nil {
		return contextutils.ErrorWithContextf("GenerateQuestionsStream not implemented in mock")
	}
	return m.realService.GenerateQuestionsStream(ctx, userConfig, req, progress, variety)
}
// GenerateChatResponse forwards to the real service; without one it returns the
// fixed string "Mock chat response".
func (m *MockAIService) GenerateChatResponse(ctx context.Context, userConfig *services.UserAIConfig, req *models.AIChatRequest) (string, error) {
	if m.realService == nil {
		return "Mock chat response", nil
	}
	return m.realService.GenerateChatResponse(ctx, userConfig, req)
}
// GenerateChatResponseStream forwards to the real service. Without one it
// offers a single "Mock streaming response" chunk using a non-blocking send:
// if no receiver is ready (or chunks is full or nil), the chunk is dropped.
// It never closes chunks and always returns nil in that case.
func (m *MockAIService) GenerateChatResponseStream(ctx context.Context, userConfig *services.UserAIConfig, req *models.AIChatRequest, chunks chan<- string) error {
	if svc := m.realService; svc != nil {
		return svc.GenerateChatResponseStream(ctx, userConfig, req, chunks)
	}

	select {
	case chunks <- "Mock streaming response":
	default:
		// Nobody ready to receive: drop the chunk rather than block.
	}
	return nil
}
// GetConcurrencyStats forwards to the real service; without one it returns the
// zero ConcurrencyStats value.
func (m *MockAIService) GetConcurrencyStats() services.ConcurrencyStats {
	if m.realService == nil {
		return services.ConcurrencyStats{}
	}
	return m.realService.GetConcurrencyStats()
}
// GetQuestionBatchSize forwards to the real service; without one every provider
// gets a batch size of 1.
func (m *MockAIService) GetQuestionBatchSize(provider string) int {
	if m.realService == nil {
		return 1
	}
	return m.realService.GetQuestionBatchSize(provider)
}
// VarietyService returns the real service's VarietyService, or nil when no real
// service is configured.
func (m *MockAIService) VarietyService() *services.VarietyService {
	if m.realService == nil {
		return nil
	}
	return m.realService.VarietyService()
}
// TemplateManager returns the real service's AITemplateManager, or nil when no
// real service is configured.
func (m *MockAIService) TemplateManager() *services.AITemplateManager {
	if m.realService == nil {
		return nil
	}
	return m.realService.TemplateManager()
}
// SupportsGrammarField forwards to the real service; without one it reports
// false for every provider.
func (m *MockAIService) SupportsGrammarField(provider string) bool {
	if m.realService == nil {
		return false
	}
	return m.realService.SupportsGrammarField(provider)
}
// Shutdown forwards to the real service; without one there is nothing to stop
// and it returns nil.
func (m *MockAIService) Shutdown(ctx context.Context) error {
	if m.realService == nil {
		return nil
	}
	return m.realService.Shutdown(ctx)
}
package handlers
import (
"context"
"database/sql"
"errors"
"fmt"
"html/template"
"net/http"
"strconv"
"strings"
"time"
"quizapp/internal/api"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"github.com/gin-gonic/gin"
)
// UserAdminHandler handles user management operations: listing, creating,
// updating and deleting users, password resets and profile updates.
type UserAdminHandler struct {
// userService performs all user, role and settings persistence.
userService services.UserServiceInterface
// cfg is the application configuration.
cfg *config.Config
// templates is set to nil by NewUserAdminHandler; no use of it is visible
// in this section. NOTE(review): confirm whether it is still needed.
templates *template.Template
// logger records handler errors and audit events (e.g. password resets).
logger *observability.Logger
}
// NewUserAdminHandler returns a UserAdminHandler wired to the given user
// service, configuration and logger. No templates are loaded: the templates
// field keeps its nil zero value.
func NewUserAdminHandler(userService services.UserServiceInterface, cfg *config.Config, logger *observability.Logger) *UserAdminHandler {
	h := &UserAdminHandler{
		userService: userService,
		cfg:         cfg,
		logger:      logger,
	}
	return h
}
// Request bodies for the user admin endpoints. Each name is a type alias (not a
// distinct type) for the generated api type, so the generated types' automatic
// validation applies and values are interchangeable with the api package.
type (
	// UserCreateRequest is the body of a create-user request.
	UserCreateRequest = api.UserCreateRequest

	// UserUpdateRequest is the body of a user/profile update request.
	UserUpdateRequest = api.UserUpdateRequest

	// PasswordResetRequest is the body of a password reset request.
	PasswordResetRequest = api.PasswordResetRequest
)
// ProfileResponse represents user profile data as returned by the user admin
// endpoints (built via convertUserToProfileResponse).
//
// Field order matters: encoding/json emits fields in declaration order, so
// reordering changes the response body layout. Pointer fields encode as JSON
// null when nil; Roles is omitted entirely when empty (omitempty).
type ProfileResponse struct {
ID int `json:"id"`
Username string `json:"username"`
Email *string `json:"email"`
Timezone *string `json:"timezone"`
LastActive *time.Time `json:"last_active"`
PreferredLanguage *string `json:"preferred_language"`
CurrentLevel *string `json:"current_level"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
AIEnabled bool `json:"ai_enabled"`
AIProvider *string `json:"ai_provider"`
AIModel *string `json:"ai_model"`
Roles []models.Role `json:"roles,omitempty"`
// IsPaused mirrors the user's paused flag. NOTE(review): exact meaning of
// "paused" is defined outside this section — confirm before relying on it.
IsPaused bool `json:"is_paused"`
}
// GetAllUsers handles GET /userz - list all users (admin only) - JSON API.
//
// Responds with {"users": [...]} where each element is a ProfileResponse. The
// list is always a JSON array: with no users the body is {"users": []}, never
// {"users": null}.
func (h *UserAdminHandler) GetAllUsers(c *gin.Context) {
	ctx := c.Request.Context()

	users, err := h.userService.GetAllUsers(ctx)
	if err != nil {
		h.logger.Error(ctx, "Error retrieving users", err, nil)
		HandleAppError(c, contextutils.WrapError(err, "failed to retrieve users"))
		return
	}

	// A non-nil, pre-sized slice: encoding/json writes a nil slice as null.
	userResponses := make([]ProfileResponse, 0, len(users))
	for i := range users {
		// Index instead of ranging by value to avoid copying each user.
		userResponses = append(userResponses, h.convertUserToProfileResponse(ctx, &users[i]))
	}

	c.JSON(http.StatusOK, gin.H{"users": userResponses})
}
// GetUsersPaginated handles GET /userz/paginated - list users with pagination (admin only).
//
// Query parameters:
//   - page, page_size: parsed by parsePagination (defaults 1 and 20);
//   - search, language, level, ai_provider, ai_model, ai_enabled, active:
//     raw string filters passed unchanged to the user service.
//
// Responds with {"users": [...], "pagination": {page, page_size, total,
// total_pages}}. "users" is always a JSON array ([] on an empty page).
func (h *UserAdminHandler) GetUsersPaginated(c *gin.Context) {
	ctx := c.Request.Context()

	page, pageSize := h.parsePagination(c)

	search := c.Query("search")
	language := c.Query("language")
	level := c.Query("level")
	aiProvider := c.Query("ai_provider")
	aiModel := c.Query("ai_model")
	aiEnabled := c.Query("ai_enabled")
	active := c.Query("active")

	users, total, err := h.userService.GetUsersPaginated(
		ctx,
		page,
		pageSize,
		search,
		language,
		level,
		aiProvider,
		aiModel,
		aiEnabled,
		active,
	)
	if err != nil {
		h.logger.Error(ctx, "Error retrieving paginated users", err, map[string]interface{}{
			"page":      page,
			"page_size": pageSize,
			"search":    search,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to retrieve users"))
		return
	}

	// A non-nil, pre-sized slice: encoding/json writes a nil slice as null.
	userResponses := make([]ProfileResponse, 0, len(users))
	for i := range users {
		userResponses = append(userResponses, h.convertUserToProfileResponse(ctx, &users[i]))
	}

	// Ceiling division; parsePagination guarantees pageSize >= 1.
	totalPages := (total + pageSize - 1) / pageSize

	c.JSON(http.StatusOK, gin.H{
		"users": userResponses,
		"pagination": gin.H{
			"page":        page,
			"page_size":   pageSize,
			"total":       total,
			"total_pages": totalPages,
		},
	})
}
// parsePagination reads the "page" and "page_size" query parameters.
//
// page defaults to 1 and is replaced only by an integer > 0. pageSize defaults
// to 20 and is replaced only by an integer in [1, 100]. Malformed or
// out-of-range values fall back to the defaults; they are not clamped.
func (h *UserAdminHandler) parsePagination(c *gin.Context) (page, pageSize int) {
	const (
		defaultPage     = 1
		defaultPageSize = 20
		maxPageSize     = 100
	)
	page, pageSize = defaultPage, defaultPageSize

	// strconv.Atoi rejects "", so an absent parameter keeps its default.
	if p, err := strconv.Atoi(c.Query("page")); err == nil && p > 0 {
		page = p
	}
	if ps, err := strconv.Atoi(c.Query("page_size")); err == nil && ps > 0 && ps <= maxPageSize {
		pageSize = ps
	}
	return page, pageSize
}
// CreateUser handles POST /userz - create new user (admin only).
//
// Username and password are required. Omitted or empty optional fields default
// to timezone "UTC", preferred language "italian" and level "A1". A supplied
// timezone must pass isValidTimezone. The username, and the email when
// non-empty, must not already be in use.
//
// The user is created first and its password set second. If setting the
// password fails, the freshly created user is deleted again. A failure of that
// clean-up is logged, and the original password error is still returned.
// Responds 201 with {"message", "user"}.
func (h *UserAdminHandler) CreateUser(c *gin.Context) {
	ctx := c.Request.Context()

	var req UserCreateRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request data",
			"",
			err,
		))
		return
	}

	// Validate required fields
	if req.Username == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}
	if req.Password == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	// Extract values from generated types, applying defaults.
	timezone := "UTC"
	if req.Timezone != nil && *req.Timezone != "" {
		timezone = *req.Timezone
		if !h.isValidTimezone(timezone) {
			HandleAppError(c, contextutils.ErrInvalidFormat)
			return
		}
	}
	preferredLanguage := "italian"
	if req.PreferredLanguage != nil && *req.PreferredLanguage != "" {
		preferredLanguage = *req.PreferredLanguage
	}
	currentLevel := "A1"
	if req.CurrentLevel != nil && *req.CurrentLevel != "" {
		currentLevel = *req.CurrentLevel
	}
	email := ""
	if req.Email != nil {
		email = string(*req.Email)
	}

	// Check if username already exists
	existingUser, err := h.userService.GetUserByUsername(ctx, req.Username)
	if err != nil {
		h.logger.Error(ctx, "Error checking existing username", err, nil)
		HandleAppError(c, contextutils.WrapError(err, "failed to check existing username"))
		return
	}
	if existingUser != nil {
		HandleAppError(c, contextutils.ErrRecordExists)
		return
	}

	// Check if email already exists (if provided)
	if email != "" {
		existingUser, err := h.userService.GetUserByEmail(ctx, email)
		if err != nil {
			h.logger.Error(ctx, "Error checking existing email", err, nil)
			HandleAppError(c, contextutils.WrapError(err, "failed to check email uniqueness"))
			return
		}
		if existingUser != nil {
			HandleAppError(c, contextutils.ErrRecordExists)
			return
		}
	}

	// Create user
	user, err := h.userService.CreateUserWithEmailAndTimezone(
		ctx,
		req.Username,
		email,
		timezone,
		preferredLanguage,
		currentLevel,
	)
	if err != nil {
		h.logger.Error(ctx, "Error creating user", err, nil)
		HandleAppError(c, contextutils.WrapError(err, "failed to create user"))
		return
	}

	// Set password
	if err := h.userService.UpdateUserPassword(ctx, user.ID, req.Password); err != nil {
		h.logger.Error(ctx, "Error setting user password", err, nil)
		// Best-effort removal of the user created above so the failed request
		// does not leave it behind. A clean-up failure is logged rather than
		// returned: the client still gets the original password error.
		if delErr := h.userService.DeleteUser(ctx, user.ID); delErr != nil {
			h.logger.Error(ctx, "Error deleting user after password setup failure", delErr, map[string]interface{}{
				"user_id": user.ID,
			})
		}
		HandleAppError(c, contextutils.WrapError(err, "failed to set user password"))
		return
	}

	// Return the created user profile
	c.JSON(http.StatusCreated, gin.H{
		"message": "User created successfully",
		"user":    h.convertUserToProfileResponse(ctx, user),
	})
}
// UpdateUser handles PUT /userz/:id - update user details (admin or self).
//
// Username, timezone, preferred language and level keep their stored values
// when the request omits them or sends them empty. Email is replaced by any
// non-nil request value. The handler then:
//   - validates the timezone and the uniqueness of a changed username or email;
//   - writes username, email and timezone via UpdateUserProfile;
//   - writes AI settings (together with language and level) via
//     UpdateUserSettings, but only when an AI field or API key is present;
//   - when SelectedRoles is present, assigns and removes roles so that the
//     user's roles match it (role names not found in GetAllRoles are ignored).
//
// Authorization (admin or self) is enforced only when GetCurrentUserID
// succeeds; routes without a current user (direct/testing) skip the check.
func (h *UserAdminHandler) UpdateUser(c *gin.Context) {
	ctx := c.Request.Context()

	userID, err := strconv.Atoi(c.Param("id"))
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Check if user exists
	user, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Error retrieving user", err, nil)
		HandleAppError(c, contextutils.WrapError(err, "database error"))
		return
	}
	if user == nil {
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	// Check authorization (admin or self) - skip for direct routes (testing)
	if currentUserID, err := GetCurrentUserID(c); err == nil {
		if err := RequireSelfOrAdmin(ctx, h.userService, currentUserID, userID); err != nil {
			if contextutils.IsError(err, contextutils.ErrForbidden) {
				HandleAppError(c, contextutils.ErrForbidden)
				return
			}
			h.logger.Error(ctx, "Error checking authorization", err, nil)
			HandleAppError(c, contextutils.WrapError(err, "failed to check authorization"))
			return
		}
	}

	var req UserUpdateRequest
	if err := c.ShouldBindJSON(&req); err != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request data",
			"",
			err,
		))
		return
	}

	// Validate timezone if provided
	if req.Timezone != nil && *req.Timezone != "" && !h.isValidTimezone(*req.Timezone) {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Start from the stored values and overlay whatever the request supplies.
	username := user.Username
	if req.Username != nil && *req.Username != "" {
		username = *req.Username
	}
	email := ""
	if user.Email.Valid {
		email = user.Email.String
	}
	if req.Email != nil {
		email = string(*req.Email)
	}
	timezone := ""
	if user.Timezone.Valid {
		timezone = user.Timezone.String
	}
	if req.Timezone != nil && *req.Timezone != "" {
		timezone = *req.Timezone
	}
	preferredLanguage := ""
	if user.PreferredLanguage.Valid {
		preferredLanguage = user.PreferredLanguage.String
	}
	if req.PreferredLanguage != nil && *req.PreferredLanguage != "" {
		preferredLanguage = *req.PreferredLanguage
	}
	currentLevel := ""
	if user.CurrentLevel.Valid {
		currentLevel = user.CurrentLevel.String
	}
	if req.CurrentLevel != nil && *req.CurrentLevel != "" {
		currentLevel = *req.CurrentLevel
	}

	// Check if new username already exists (if changed)
	if username != user.Username {
		existingUser, err := h.userService.GetUserByUsername(ctx, username)
		if err != nil {
			h.logger.Error(ctx, "Error checking existing username", err, nil)
			HandleAppError(c, contextutils.WrapError(err, "failed to check username uniqueness"))
			return
		}
		if existingUser != nil {
			HandleAppError(c, contextutils.ErrRecordExists)
			return
		}
	}

	// Check email uniqueness whenever a non-empty email differs from the stored
	// one. This includes a user who currently has no email (Email.Valid false),
	// so such a user cannot take an address another account already uses.
	if email != "" && (!user.Email.Valid || email != user.Email.String) {
		existingUser, err := h.userService.GetUserByEmail(ctx, email)
		if err != nil {
			h.logger.Error(ctx, "Error checking existing email", err, nil)
			HandleAppError(c, contextutils.WrapError(err, "failed to check email uniqueness"))
			return
		}
		if existingUser != nil {
			HandleAppError(c, contextutils.ErrRecordExists)
			return
		}
	}

	// Update user profile
	if err := h.userService.UpdateUserProfile(ctx, userID, username, email, timezone); err != nil {
		h.logger.Error(ctx, "Error updating user profile", err, nil)
		// Check if the error is due to user not found
		if errors.Is(err, contextutils.ErrRecordNotFound) {
			HandleAppError(c, contextutils.ErrRecordNotFound)
			return
		}
		HandleAppError(c, contextutils.WrapError(err, "failed to update user profile"))
		return
	}

	// Handle AI settings update if provided.
	// NOTE(review): preferred language and level are only persisted through
	// UpdateUserSettings, which runs only when an AI field or API key is sent;
	// a request changing only language/level is not applied here (unlike
	// UpdateCurrentUserProfile). Also, AIEnabled becomes false whenever this
	// branch runs without ai_enabled in the request. Confirm both are intended.
	needsAIUpdate := req.AiEnabled != nil || (req.AiProvider != nil && *req.AiProvider != "") || (req.AiModel != nil && *req.AiModel != "") || (req.ApiKey != nil && *req.ApiKey != "")
	if needsAIUpdate {
		aiSettings := &models.UserSettings{
			Language:  preferredLanguage,
			Level:     currentLevel,
			AIEnabled: req.AiEnabled != nil && *req.AiEnabled,
		}
		if req.AiProvider != nil && *req.AiProvider != "" {
			aiSettings.AIProvider = *req.AiProvider
		} else if user.AIProvider.Valid {
			aiSettings.AIProvider = user.AIProvider.String
		}
		if req.AiModel != nil && *req.AiModel != "" {
			aiSettings.AIModel = *req.AiModel
		} else if user.AIModel.Valid {
			aiSettings.AIModel = user.AIModel.String
		}
		if req.ApiKey != nil && *req.ApiKey != "" {
			aiSettings.AIAPIKey = *req.ApiKey
		}

		if err := h.userService.UpdateUserSettings(ctx, userID, aiSettings); err != nil {
			h.logger.Error(ctx, "Error updating user AI settings", err, nil)
			// Check if the error is due to user not found
			if errors.Is(err, contextutils.ErrRecordNotFound) {
				HandleAppError(c, contextutils.ErrRecordNotFound)
				return
			}
			HandleAppError(c, contextutils.WrapError(err, "failed to update AI settings"))
			return
		}
	}

	// Handle role updates if provided
	if req.SelectedRoles != nil {
		currentRoles, err := h.userService.GetUserRoles(ctx, userID)
		if err != nil {
			h.logger.Error(ctx, "Error getting current user roles", err, nil)
			HandleAppError(c, contextutils.WrapError(err, "failed to get current user roles"))
			return
		}
		allRoles, err := h.userService.GetAllRoles(ctx)
		if err != nil {
			h.logger.Error(ctx, "Error getting all roles", err, nil)
			HandleAppError(c, contextutils.WrapError(err, "failed to get available roles"))
			return
		}

		// Maps for efficient membership lookup.
		currentRoleNames := make(map[string]bool, len(currentRoles))
		for _, role := range currentRoles {
			currentRoleNames[role.Name] = true
		}
		requestedRoleNames := make(map[string]bool, len(*req.SelectedRoles))
		for _, roleName := range *req.SelectedRoles {
			requestedRoleNames[roleName] = true
		}

		// Assign requested roles the user does not have yet.
		for _, roleName := range *req.SelectedRoles {
			if currentRoleNames[roleName] {
				continue
			}
			var roleToAdd *models.Role
			for i := range allRoles {
				if allRoles[i].Name == roleName {
					roleToAdd = &allRoles[i]
					break
				}
			}
			if roleToAdd == nil {
				// Unknown role name: ignored.
				continue
			}
			if err := h.userService.AssignRole(ctx, userID, roleToAdd.ID); err != nil {
				h.logger.Error(ctx, "Error assigning role to user", err, map[string]interface{}{
					"user_id":   userID,
					"role_id":   roleToAdd.ID,
					"role_name": roleName,
				})
				HandleAppError(c, contextutils.WrapError(err, "failed to assign role"))
				return
			}
			// Mark as held so a duplicate name in the request is not assigned twice.
			currentRoleNames[roleName] = true
		}

		// Remove roles that are no longer selected
		for _, role := range currentRoles {
			if requestedRoleNames[role.Name] {
				continue
			}
			if err := h.userService.RemoveRole(ctx, userID, role.ID); err != nil {
				h.logger.Error(ctx, "Error removing role from user", err, map[string]interface{}{
					"user_id":   userID,
					"role_id":   role.ID,
					"role_name": role.Name,
				})
				HandleAppError(c, contextutils.WrapError(err, "failed to remove role"))
				return
			}
		}
	}

	// Get updated user
	updatedUser, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Error retrieving updated user", err, nil)
		HandleAppError(c, contextutils.WrapError(err, "failed to retrieve updated user"))
		return
	}
	if updatedUser == nil {
		// The user vanished between the update and the re-read.
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	c.JSON(http.StatusOK, gin.H{
		"message": "User updated successfully",
		"user":    h.convertUserToProfileResponse(ctx, updatedUser),
	})
}
// DeleteUser handles DELETE /userz/:id - delete user (admin only).
//
// A non-numeric ID yields ErrInvalidFormat, an unknown ID ErrRecordNotFound;
// otherwise the user is deleted and a confirmation message is returned.
func (h *UserAdminHandler) DeleteUser(c *gin.Context) {
	ctx := c.Request.Context()

	userID, convErr := strconv.Atoi(c.Param("id"))
	if convErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Make sure the user exists before attempting the delete.
	switch user, lookupErr := h.userService.GetUserByID(ctx, userID); {
	case lookupErr != nil:
		h.logger.Error(ctx, "Error retrieving user", lookupErr, nil)
		HandleAppError(c, contextutils.WrapError(lookupErr, "database error"))
		return
	case user == nil:
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	if delErr := h.userService.DeleteUser(ctx, userID); delErr != nil {
		h.logger.Error(ctx, "Error deleting user", delErr, nil)
		HandleAppError(c, contextutils.WrapError(delErr, "failed to delete user"))
		return
	}

	c.JSON(http.StatusOK, gin.H{"message": "User deleted successfully"})
}
// ResetUserPassword handles POST /userz/:id/reset-password - reset user password (admin only).
//
// The target user must exist and the body must carry a non-empty NewPassword.
// Each step logs with the user_id; a successful reset is logged at info level
// together with the username.
func (h *UserAdminHandler) ResetUserPassword(c *gin.Context) {
	ctx := c.Request.Context()

	userID, convErr := strconv.Atoi(c.Param("id"))
	if convErr != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	user, lookupErr := h.userService.GetUserByID(ctx, userID)
	switch {
	case lookupErr != nil:
		h.logger.Error(ctx, "Error retrieving user", lookupErr, map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.WrapError(lookupErr, "database error"))
		return
	case user == nil:
		h.logger.Warn(ctx, "User not found for password reset", map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	var body PasswordResetRequest
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		h.logger.Error(ctx, "Invalid request data for password reset", bindErr, map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request data",
			"",
			bindErr,
		))
		return
	}

	if len(body.NewPassword) == 0 {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	if updErr := h.userService.UpdateUserPassword(ctx, userID, body.NewPassword); updErr != nil {
		h.logger.Error(ctx, "Error updating user password", updErr, map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.WrapError(updErr, "failed to update password"))
		return
	}

	h.logger.Info(ctx, "Password reset successful", map[string]interface{}{"user_id": userID, "username": user.Username})
	c.JSON(http.StatusOK, gin.H{"message": "Password reset successfully"})
}
// UpdateCurrentUserProfile handles PUT /userz/profile - update current user profile
func (h *UserAdminHandler) UpdateCurrentUserProfile(c *gin.Context) {
// Get user ID from context/session
userID, err := GetCurrentUserID(c)
if err != nil {
HandleAppError(c, contextutils.ErrUnauthorized)
return
}
var req UserUpdateRequest
if err := c.ShouldBindJSON(&req); err != nil {
HandleAppError(c, contextutils.NewAppErrorWithCause(
contextutils.ErrorCodeInvalidInput,
contextutils.SeverityWarn,
"Invalid request data",
"",
err,
))
return
}
// Validate timezone if provided
if req.Timezone != nil && *req.Timezone != "" && !h.isValidTimezone(*req.Timezone) {
HandleAppError(c, contextutils.ErrInvalidFormat)
return
}
// Email validation is handled automatically by openapi_types.Email
// Get current user
user, err := h.userService.GetUserByID(c.Request.Context(), userID)
if err != nil {
h.logger.Error(c.Request.Context(), "Error retrieving user", err, nil)
HandleAppError(c, contextutils.WrapError(err, "database error"))
return
}
if user == nil {
HandleAppError(c, contextutils.ErrRecordNotFound)
return
}
// Check authorization (self-only for this endpoint)
if err := RequireSelfOrAdmin(c.Request.Context(), h.userService, userID, userID); err != nil {
if contextutils.IsError(err, contextutils.ErrForbidden) {
HandleAppError(c, contextutils.ErrForbidden)
return
}
h.logger.Error(c.Request.Context(), "Error checking authorization", err, nil)
HandleAppError(c, contextutils.WrapError(err, "failed to check authorization"))
return
}
// Use existing values if not provided in request
username := user.Username
if req.Username != nil && *req.Username != "" {
username = *req.Username
}
email := ""
if user.Email.Valid {
email = user.Email.String
}
if req.Email != nil {
email = string(*req.Email)
}
timezone := ""
if user.Timezone.Valid {
timezone = user.Timezone.String
}
if req.Timezone != nil && *req.Timezone != "" {
timezone = *req.Timezone
}
// Check if new username already exists (if changed)
if username != user.Username {
existingUser, err := h.userService.GetUserByUsername(c.Request.Context(), username)
if err != nil {
h.logger.Error(c.Request.Context(), "Error checking existing username", err, nil)
HandleAppError(c, contextutils.WrapError(err, "failed to check username uniqueness"))
return
}
if existingUser != nil {
HandleAppError(c, contextutils.ErrRecordExists)
return
}
}
// Check if new email already exists (if changed)
if email != "" && user.Email.Valid && email != user.Email.String {
existingUser, err := h.userService.GetUserByEmail(c.Request.Context(), email)
if err != nil {
h.logger.Error(c.Request.Context(), "Error checking existing email", err, nil)
HandleAppError(c, contextutils.WrapError(err, "failed to check email uniqueness"))
return
}
if existingUser != nil {
HandleAppError(c, contextutils.ErrRecordExists)
return
}
}
// Use existing AI values if not provided in request
preferredLanguage := ""
if user.PreferredLanguage.Valid {
preferredLanguage = user.PreferredLanguage.String
}
if req.PreferredLanguage != nil && *req.PreferredLanguage != "" {
preferredLanguage = *req.PreferredLanguage
}
currentLevel := ""
if user.CurrentLevel.Valid {
currentLevel = user.CurrentLevel.String
}
if req.CurrentLevel != nil && *req.CurrentLevel != "" {
currentLevel = *req.CurrentLevel
}
// Update user profile
err = h.userService.UpdateUserProfile(c.Request.Context(), userID, username, email, timezone)
if err != nil {
h.logger.Error(c.Request.Context(), "Error updating user profile", err, nil)
HandleAppError(c, contextutils.WrapError(err, "failed to update user profile"))
return
}
// Handle AI settings update if provided
needsAIUpdate := req.AiEnabled != nil || (req.AiProvider != nil && *req.AiProvider != "") || (req.AiModel != nil && *req.AiModel != "") || (req.PreferredLanguage != nil && *req.PreferredLanguage != "") || (req.CurrentLevel != nil && *req.CurrentLevel != "") || (req.ApiKey != nil && *req.ApiKey != "")
if needsAIUpdate {
aiSettings := &models.UserSettings{
Language: preferredLanguage,
Level: currentLevel,
AIEnabled: req.AiEnabled != nil && *req.AiEnabled,
}
if req.AiProvider != nil && *req.AiProvider != "" {
aiSettings.AIProvider = *req.AiProvider
} else if user.AIProvider.Valid {
aiSettings.AIProvider = user.AIProvider.String
}
if req.AiModel != nil && *req.AiModel != "" {
aiSettings.AIModel = *req.AiModel
} else if user.AIModel.Valid {
aiSettings.AIModel = user.AIModel.String
}
if req.ApiKey != nil && *req.ApiKey != "" {
aiSettings.AIAPIKey = *req.ApiKey
}
err = h.userService.UpdateUserSettings(c.Request.Context(), userID, aiSettings)
if err != nil {
h.logger.Error(c.Request.Context(), "Error updating user AI settings", err, nil)
HandleAppError(c, contextutils.WrapError(err, "failed to update AI settings"))
return
}
}
// Get updated user
updatedUser, err := h.userService.GetUserByID(c.Request.Context(), userID)
if err != nil {
h.logger.Error(c.Request.Context(), "Error retrieving updated user", err, nil)
HandleAppError(c, contextutils.WrapError(err, "failed to retrieve updated profile"))
return
}
c.JSON(http.StatusOK, gin.H{
"message": "Profile updated successfully",
"user": h.convertUserToProfileResponse(c.Request.Context(), updatedUser),
})
}
// isUserPaused reports whether question generation is paused for userID.
//
// It reads the "user_pause_<id>" row from the worker_settings table and treats
// only the literal value "true" as paused. A missing row means "not paused";
// any other query error is logged at warn level and also yields false, so the
// caller never fails because of this lookup.
func (h *UserAdminHandler) isUserPaused(ctx context.Context, userID int) bool {
	const pauseQuery = `SELECT setting_value FROM worker_settings WHERE setting_key = $1`
	var settingValue string
	key := fmt.Sprintf("user_pause_%d", userID)

	scanErr := h.userService.GetDB().QueryRowContext(ctx, pauseQuery, key).Scan(&settingValue)
	switch {
	case scanErr == nil:
		return settingValue == "true"
	case errors.Is(scanErr, sql.ErrNoRows):
		// No setting stored for this user: not paused.
		return false
	default:
		// Best effort: report "not paused" instead of propagating the failure.
		h.logger.Warn(ctx, "Failed to check user pause status", map[string]interface{}{
			"user_id": userID,
			"error":   scanErr.Error(),
		})
		return false
	}
}
// Helper functions
// convertUserToProfileResponse converts a User model to the ProfileResponse
// returned by the profile endpoints.
//
// Roles come from userService.GetUserRoles. A lookup failure is logged and,
// like a nil result, reported as an empty (non-nil) role slice, so
// encoding/json emits [] rather than null for it. IsPaused is resolved
// through isUserPaused. user must be non-nil (its fields are read directly).
func (h *UserAdminHandler) convertUserToProfileResponse(ctx context.Context, user *models.User) ProfileResponse {
	roles, err := h.userService.GetUserRoles(ctx, user.ID)
	if err != nil {
		// Log error but don't fail the response
		h.logger.Warn(ctx, "Failed to get user roles", map[string]interface{}{
			"user_id": user.ID,
			"error":   err.Error(),
		})
		roles = nil
	}
	if roles == nil {
		// Normalize both the error case and a nil success result to an empty list.
		roles = []models.Role{}
	}

	return ProfileResponse{
		ID:                user.ID,
		Username:          user.Username,
		Email:             nullStringToPointer(user.Email),
		Timezone:          nullStringToPointer(user.Timezone),
		LastActive:        nullTimeToPointer(user.LastActive),
		PreferredLanguage: nullStringToPointer(user.PreferredLanguage),
		CurrentLevel:      nullStringToPointer(user.CurrentLevel),
		CreatedAt:         user.CreatedAt,
		UpdatedAt:         user.UpdatedAt,
		AIEnabled:         user.AIEnabled.Valid && user.AIEnabled.Bool,
		AIProvider:        nullStringToPointer(user.AIProvider),
		AIModel:           nullStringToPointer(user.AIModel),
		Roles:             roles,
		IsPaused:          h.isUserPaused(ctx, user.ID),
	}
}
// isValidTimezone reports whether tz names a timezone acceptable for a user.
//
// A name is accepted when time.LoadLocation can resolve it. Any
// case-insensitive spelling of "UTC" (e.g. "utc") is also accepted as a
// fallback. "Local" is rejected: LoadLocation maps it to the server's own
// zone, which says nothing about the user. The empty string resolves to UTC
// in LoadLocation and therefore stays accepted.
func (h *UserAdminHandler) isValidTimezone(tz string) bool {
	if tz == "Local" {
		return false
	}
	if _, err := time.LoadLocation(tz); err == nil {
		return true
	}
	return strings.EqualFold(tz, "UTC")
}
// nullStringToPointer maps a sql.NullString to *string: nil for SQL NULL,
// otherwise a pointer to a copy of the string (never aliasing the caller's value).
func nullStringToPointer(ns sql.NullString) *string {
	if !ns.Valid {
		return nil
	}
	s := ns.String
	return &s
}
// nullTimeToPointer maps a sql.NullTime to *time.Time: nil for SQL NULL,
// otherwise a pointer to a copy of the time (never aliasing the caller's value).
func nullTimeToPointer(nt sql.NullTime) *time.Time {
	if !nt.Valid {
		return nil
	}
	t := nt.Time
	return &t
}
package handlers
import (
"errors"
"fmt"
"html/template"
"net/http"
"strconv"
"strings"
"time"
"quizapp/internal/config"
"quizapp/internal/observability"
"quizapp/internal/services"
contextutils "quizapp/internal/utils"
"quizapp/internal/worker"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/otel/attribute"
)
// WorkerAdminHandler handles worker administration endpoints: global and
// per-user pause controls, worker status/health, priority and learning
// analytics, notification records, and per-user daily question management.
//
// worker is the in-process worker instance and may be nil; several handlers
// check h.worker != nil before using it. templates is always set to nil by
// NewWorkerAdminHandlerWithLogger.
// NOTE(review): questionService and templates are not referenced by the
// handlers visible in this file section — confirm whether they are still needed.
type WorkerAdminHandler struct {
	userService          services.UserServiceInterface
	questionService      services.QuestionServiceInterface
	aiService            services.AIServiceInterface
	config               *config.Config
	worker               *worker.Worker
	workerService        services.WorkerServiceInterface
	templates            *template.Template
	learningService      services.LearningServiceInterface
	dailyQuestionService services.DailyQuestionServiceInterface
	logger               *observability.Logger
}
// NewWorkerAdminHandlerWithLogger wires a WorkerAdminHandler from its service
// dependencies, configuration and logger. The local worker instance may be
// nil (handlers that need it check for that); templates is left at its zero
// value (nil).
func NewWorkerAdminHandlerWithLogger(
	userService services.UserServiceInterface,
	questionService services.QuestionServiceInterface,
	aiService services.AIServiceInterface,
	cfg *config.Config,
	worker *worker.Worker,
	workerService services.WorkerServiceInterface,
	learningService services.LearningServiceInterface,
	dailyQuestionService services.DailyQuestionServiceInterface,
	logger *observability.Logger,
) *WorkerAdminHandler {
	h := &WorkerAdminHandler{
		config:               cfg,
		logger:               logger,
		worker:               worker,
		workerService:        workerService,
		userService:          userService,
		questionService:      questionService,
		aiService:            aiService,
		learningService:      learningService,
		dailyQuestionService: dailyQuestionService,
	}
	return h
}
// GetWorkerDetails responds with the local worker's status and run history
// (zero values when no local worker is attached) together with the global
// pause flag. A failure to read the pause flag is logged and reported as
// "not paused".
func (h *WorkerAdminHandler) GetWorkerDetails(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_worker_details")
	defer span.End()

	var (
		status  worker.Status
		history []worker.RunRecord
	)
	if w := h.worker; w != nil {
		status, history = w.GetStatus(), w.GetHistory()
	}

	paused, pauseErr := h.workerService.IsGlobalPaused(ctx)
	if pauseErr != nil {
		// Non-fatal: keep serving the rest of the details.
		h.logger.Warn(ctx, "Failed to get global pause status", map[string]interface{}{"error": pauseErr.Error()})
		paused = false
	}

	c.JSON(http.StatusOK, gin.H{
		"status":        status,
		"history":       history,
		"global_paused": paused,
	})
}
// GetActivityLogs responds with the local worker's recent activity logs, or a
// service-unavailable error when no local worker is attached.
func (h *WorkerAdminHandler) GetActivityLogs(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_activity_logs")
	defer span.End()

	w := h.worker
	if w == nil {
		HandleAppError(c, contextutils.ErrServiceUnavailable)
		return
	}
	c.JSON(http.StatusOK, gin.H{"logs": w.GetActivityLogs()})
}
// PauseWorker persists the global pause flag through the worker service and,
// when a local worker instance is attached, pauses it too. If persisting the
// flag fails, the local worker is left untouched.
func (h *WorkerAdminHandler) PauseWorker(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "pause_worker")
	defer span.End()

	err := h.workerService.SetGlobalPause(ctx, true)
	if err != nil {
		HandleAppError(c, contextutils.WrapError(err, "failed to pause worker globally"))
		return
	}

	if w := h.worker; w != nil {
		w.Pause(ctx)
	}
	c.JSON(http.StatusOK, gin.H{"message": "Worker paused globally"})
}
// ResumeWorker clears the global pause flag through the worker service and,
// when a local worker instance is attached, resumes it too. If clearing the
// flag fails, the local worker is left untouched.
func (h *WorkerAdminHandler) ResumeWorker(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "resume_worker")
	defer span.End()

	err := h.workerService.SetGlobalPause(ctx, false)
	if err != nil {
		HandleAppError(c, contextutils.WrapError(err, "failed to resume worker globally"))
		return
	}

	if w := h.worker; w != nil {
		w.Resume(ctx)
	}
	c.JSON(http.StatusOK, gin.H{"message": "Worker resumed globally"})
}
// GetWorkerStatus responds with the stored status of a single worker instance,
// chosen by the "instance" query parameter (defaults to "default").
func (h *WorkerAdminHandler) GetWorkerStatus(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_worker_status")
	defer span.End()

	status, err := h.workerService.GetWorkerStatus(ctx, c.DefaultQuery("instance", "default"))
	if err != nil {
		HandleAppError(c, contextutils.WrapError(err, "failed to get worker status"))
		return
	}
	c.JSON(http.StatusOK, status)
}
// TriggerWorkerRun asks the local worker instance to start a manual run.
// Without a local worker it responds with a service-unavailable error.
func (h *WorkerAdminHandler) TriggerWorkerRun(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "trigger_worker_run")
	defer span.End()

	if h.worker == nil {
		HandleAppError(c, contextutils.ErrServiceUnavailable)
		return
	}
	h.worker.TriggerManualRun()
	c.JSON(http.StatusOK, gin.H{"message": "Worker run triggered"})
}
// PauseWorkerUser pauses question generation for one user.
//
// Request body: {"user_id": <int>} (required). Invalid bodies are rejected
// with an invalid-input error; a failure to store the flag is wrapped and
// returned.
func (h *WorkerAdminHandler) PauseWorkerUser(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "pause_user")
	defer span.End()

	var body struct {
		UserID int `json:"user_id" binding:"required"`
	}
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		invalid := contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request",
			"",
			bindErr,
		)
		HandleAppError(c, invalid)
		return
	}

	if pauseErr := h.workerService.SetUserPause(ctx, body.UserID, true); pauseErr != nil {
		HandleAppError(c, contextutils.WrapError(pauseErr, "failed to pause user"))
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": "User paused successfully"})
}
// ResumeWorkerUser resumes question generation for one user.
//
// Request body: {"user_id": <int>} (required). Invalid bodies are rejected
// with an invalid-input error; a failure to clear the flag is wrapped and
// returned.
func (h *WorkerAdminHandler) ResumeWorkerUser(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "resume_user")
	defer span.End()

	var body struct {
		UserID int `json:"user_id" binding:"required"`
	}
	if bindErr := c.ShouldBindJSON(&body); bindErr != nil {
		invalid := contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request",
			"",
			bindErr,
		)
		HandleAppError(c, invalid)
		return
	}

	if resumeErr := h.workerService.SetUserPause(ctx, body.UserID, false); resumeErr != nil {
		HandleAppError(c, contextutils.WrapError(resumeErr, "failed to resume user"))
		return
	}
	c.JSON(http.StatusOK, gin.H{"message": "User resumed successfully"})
}
// GetWorkerUsers responds with every user's id, username and per-user pause
// flag for the worker controls.
//
// A failure to read one user's pause flag is logged and that user is reported
// as not paused, so a single bad lookup does not fail the whole listing. The
// "users" list is always a JSON array (empty when there are no users), never
// null.
func (h *WorkerAdminHandler) GetWorkerUsers(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_worker_users")
	defer span.End()

	users, err := h.userService.GetAllUsers(ctx)
	if err != nil {
		HandleAppError(c, contextutils.WrapError(err, "failed to get users"))
		return
	}

	userList := make([]gin.H, 0, len(users))
	for _, user := range users {
		isPaused, pauseErr := h.workerService.IsUserPaused(ctx, user.ID)
		if pauseErr != nil {
			h.logger.Warn(ctx, "Failed to get user pause status", map[string]interface{}{
				"user_id": user.ID,
				"error":   pauseErr.Error(),
			})
			isPaused = false
		}
		userList = append(userList, gin.H{
			"id":        user.ID,
			"username":  user.Username,
			"is_paused": isPaused,
		})
	}
	c.JSON(http.StatusOK, gin.H{"users": userList})
}
// GetSystemHealth responds with the worker service's health report, or a
// wrapped error when the report cannot be produced.
func (h *WorkerAdminHandler) GetSystemHealth(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_system_health")
	defer span.End()

	report, healthErr := h.workerService.GetWorkerHealth(ctx)
	if healthErr != nil {
		HandleAppError(c, contextutils.WrapError(healthErr, "failed to get system health"))
		return
	}
	c.JSON(http.StatusOK, report)
}
// GetAIConcurrencyStats responds with the AI service's concurrency metrics
// under "ai_concurrency", or an AI-provider-unavailable error when no AI
// service is configured.
func (h *WorkerAdminHandler) GetAIConcurrencyStats(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_ai_concurrency_stats")
	defer span.End()

	if h.aiService == nil {
		HandleAppError(c, contextutils.ErrAIProviderUnavailable)
		return
	}
	c.JSON(http.StatusOK, gin.H{"ai_concurrency": h.aiService.GetConcurrencyStats()})
}
// GetPriorityAnalytics responds with the global priority-score distribution
// and the top 5 high-priority questions. Both lookups are best-effort: on
// error they are logged and replaced with zeroed/empty defaults.
func (h *WorkerAdminHandler) GetPriorityAnalytics(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_priority_analytics")
	defer span.End()

	distribution, distErr := h.learningService.GetPriorityScoreDistribution(ctx)
	if distErr != nil {
		h.logger.Error(ctx, "Error getting priority score distribution", distErr, map[string]interface{}{})
		distribution = map[string]interface{}{
			"high":    0,
			"medium":  0,
			"low":     0,
			"average": 0.0,
		}
	}

	topQuestions, topErr := h.learningService.GetHighPriorityQuestions(ctx, 5)
	if topErr != nil {
		h.logger.Error(ctx, "Error getting high priority questions", topErr, map[string]interface{}{})
		topQuestions = []map[string]interface{}{}
	}

	c.JSON(http.StatusOK, gin.H{
		"distribution":          distribution,
		"highPriorityQuestions": topQuestions,
	})
}
// GetUserPriorityAnalytics returns priority analytics for one user.
//
// Path parameter userID must be an integer (otherwise invalid-format). A
// failed user lookup is logged and returned as a wrapped error; a user that
// does not exist yields not-found. The distribution, top 10 high-priority
// questions, top 5 weak areas and learning preferences are each best-effort:
// on error they are logged and replaced with a zeroed/empty default (nil for
// preferences).
func (h *WorkerAdminHandler) GetUserPriorityAnalytics(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_user_priority_analytics")
	defer span.End()

	userIDStr := c.Param("userID")
	userID, err := strconv.Atoi(userIDStr)
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	// Verify user exists. Keep database failures distinct from "not found" so
	// they are not masked as a 404-style response.
	user, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get user for priority analytics", err, map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.WrapError(err, "failed to get user"))
		return
	}
	if user == nil {
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	logFields := map[string]interface{}{"user_id": userID}

	distribution, err := h.learningService.GetUserPriorityScoreDistribution(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Error getting user priority score distribution", err, logFields)
		distribution = map[string]interface{}{
			"high":    0,
			"medium":  0,
			"low":     0,
			"average": 0.0,
		}
	}

	highPriorityQuestions, err := h.learningService.GetUserHighPriorityQuestions(ctx, userID, 10)
	if err != nil {
		h.logger.Error(ctx, "Error getting user high priority questions", err, logFields)
		highPriorityQuestions = []map[string]interface{}{}
	}

	weakAreas, err := h.learningService.GetUserWeakAreas(ctx, userID, 5)
	if err != nil {
		h.logger.Error(ctx, "Error getting user weak areas", err, logFields)
		weakAreas = []map[string]interface{}{}
	}

	preferences, err := h.learningService.GetUserLearningPreferences(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Error getting user learning preferences", err, logFields)
		preferences = nil
	}

	c.JSON(http.StatusOK, gin.H{
		"user": gin.H{
			"id":       user.ID,
			"username": user.Username,
		},
		"distribution":          distribution,
		"highPriorityQuestions": highPriorityQuestions,
		"weakAreas":             weakAreas,
		"learningPreferences":   preferences,
	})
}
// GetUserPerformanceAnalytics responds with the top 5 weak areas by topic and
// the learning-preferences usage summary. Each lookup is best-effort: on
// error it is logged and replaced with an empty value.
func (h *WorkerAdminHandler) GetUserPerformanceAnalytics(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_user_performance_analytics")
	defer span.End()

	weak, weakErr := h.learningService.GetWeakAreasByTopic(ctx, 5)
	if weakErr != nil {
		h.logger.Error(ctx, "Error getting weak areas", weakErr, map[string]interface{}{})
		weak = []map[string]interface{}{}
	}

	usage, usageErr := h.learningService.GetLearningPreferencesUsage(ctx)
	if usageErr != nil {
		h.logger.Error(ctx, "Error getting learning preferences usage", usageErr, map[string]interface{}{})
		usage = map[string]interface{}{}
	}

	c.JSON(http.StatusOK, gin.H{
		"weakAreas":           weak,
		"learningPreferences": usage,
	})
}
// GetGenerationIntelligence responds with the question-type gap analysis and
// the generation suggestions. Errors are logged, and both fields are always
// JSON arrays: an error or a nil result is reported as an empty list.
func (h *WorkerAdminHandler) GetGenerationIntelligence(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_generation_intelligence")
	defer span.End()

	gaps, gapErr := h.learningService.GetQuestionTypeGaps(ctx)
	if gapErr != nil {
		h.logger.Error(ctx, "Error getting gap analysis", gapErr, map[string]interface{}{})
	}
	if gapErr != nil || gaps == nil {
		gaps = []map[string]interface{}{}
	}

	suggestions, sugErr := h.learningService.GetGenerationSuggestions(ctx)
	if sugErr != nil {
		h.logger.Error(ctx, "Error getting generation suggestions", sugErr, map[string]interface{}{})
	}
	if sugErr != nil || suggestions == nil {
		suggestions = []map[string]interface{}{}
	}

	c.JSON(http.StatusOK, gin.H{
		"gapAnalysis":           gaps,
		"generationSuggestions": suggestions,
	})
}
// GetSystemHealthAnalytics responds with the priority-system performance
// metrics and the background-jobs status. Each lookup is best-effort: on
// error it is logged and replaced with an empty map.
func (h *WorkerAdminHandler) GetSystemHealthAnalytics(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_system_health_analytics")
	defer span.End()

	perf, perfErr := h.learningService.GetPrioritySystemPerformance(ctx)
	if perfErr != nil {
		h.logger.Error(ctx, "Error getting performance metrics", perfErr, map[string]interface{}{})
		perf = map[string]interface{}{}
	}

	jobs, jobsErr := h.learningService.GetBackgroundJobsStatus(ctx)
	if jobsErr != nil {
		h.logger.Error(ctx, "Error getting background jobs status", jobsErr, map[string]interface{}{})
		jobs = map[string]interface{}{}
	}

	c.JSON(http.StatusOK, gin.H{
		"performance":    perf,
		"backgroundJobs": jobs,
	})
}
// GetUserComparisonAnalytics compares priority analytics across several users.
//
// The user_ids query parameter is a comma-separated list of integer IDs.
// Blank entries are ignored, a non-integer entry is rejected with an
// invalid-format error, and a list with no usable IDs is a missing-parameter
// error. For each ID the response holds the user's id/username, priority
// score distribution and top 3 weak areas. IDs that cannot be loaded (lookup
// error or no such user) are skipped. The analytics lookups are best-effort:
// failures are logged and the value returned by the service is used as-is.
// "comparison" is always a JSON array (possibly empty), never null.
func (h *WorkerAdminHandler) GetUserComparisonAnalytics(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_user_comparison_analytics")
	defer span.End()

	userIDsParam := c.Query("user_ids")
	if userIDsParam == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	// strings.Split on a non-empty string always yields at least one element,
	// so emptiness is only meaningful after blank entries are dropped.
	var userIDs []int
	for _, idStr := range strings.Split(userIDsParam, ",") {
		idStr = strings.TrimSpace(idStr)
		if idStr == "" {
			continue
		}
		id, err := strconv.Atoi(idStr)
		if err != nil {
			HandleAppError(c, contextutils.NewAppErrorWithCause(
				contextutils.ErrorCodeInvalidFormat,
				contextutils.SeverityWarn,
				"Invalid user ID",
				idStr,
				err,
			))
			return
		}
		userIDs = append(userIDs, id)
	}
	if len(userIDs) == 0 {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	comparisonData := make([]gin.H, 0, len(userIDs))
	for _, userID := range userIDs {
		user, err := h.userService.GetUserByID(ctx, userID)
		if err != nil {
			h.logger.Warn(ctx, "Skipping user in comparison: lookup failed", map[string]interface{}{
				"user_id": userID,
				"error":   err.Error(),
			})
			continue
		}
		if user == nil {
			// Unknown user: skip instead of dereferencing a nil pointer below.
			continue
		}

		distribution, distErr := h.learningService.GetUserPriorityScoreDistribution(ctx, userID)
		if distErr != nil {
			h.logger.Warn(ctx, "Failed to get user priority score distribution", map[string]interface{}{
				"user_id": userID,
				"error":   distErr.Error(),
			})
		}
		weakAreas, weakErr := h.learningService.GetUserWeakAreas(ctx, userID, 3)
		if weakErr != nil {
			h.logger.Warn(ctx, "Failed to get user weak areas", map[string]interface{}{
				"user_id": userID,
				"error":   weakErr.Error(),
			})
		}

		comparisonData = append(comparisonData, gin.H{
			"user": gin.H{
				"id":       user.ID,
				"username": user.Username,
			},
			"distribution": distribution,
			"weakAreas":    weakAreas,
		})
	}
	c.JSON(http.StatusOK, gin.H{"comparison": comparisonData})
}
// GetConfigz responds with the merged application configuration as indented
// JSON.
//
// NOTE(review): this serializes the whole *config.Config. Unless that type
// hides sensitive fields from JSON, any credentials it holds are returned to
// the caller — confirm the route is admin-only and review the config's JSON tags.
func (h *WorkerAdminHandler) GetConfigz(c *gin.Context) {
	_, span := observability.TraceHandlerFunction(c.Request.Context(), "get_configz")
	defer span.End()

	cfg := h.config
	c.IndentedJSON(http.StatusOK, cfg)
}
// GetNotificationStats responds with the worker service's notification
// statistics. On failure it logs the error and returns 500 with the error
// text under "details".
func (h *WorkerAdminHandler) GetNotificationStats(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_notification_stats")
	defer span.End()

	stats, statsErr := h.workerService.GetNotificationStats(ctx)
	if statsErr == nil {
		c.JSON(http.StatusOK, stats)
		return
	}

	h.logger.Error(ctx, "Failed to get notification stats", statsErr, nil)
	c.JSON(http.StatusInternalServerError, gin.H{
		"error":   "Failed to get notification statistics",
		"details": statsErr.Error(),
	})
}
// GetNotificationErrors responds with a page of notification errors.
//
// Pagination is parsed with ParsePagination(c, 1, 20, 100) — presumably
// default page 1, default size 20, max size 100; confirm against the helper.
// The optional error_type, notification_type and resolved filters are passed
// to the worker service unchanged. The page is written by WritePaginated under
// the "errors" key with the service's stats attached. A service failure is
// logged and returned as 500 with the error text under "details".
func (h *WorkerAdminHandler) GetNotificationErrors(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_notification_errors")
	defer span.End()

	page, pageSize := ParsePagination(c, 1, 20, 100)
	f := ParseFilters(c, "error_type", "notification_type", "resolved")
	errorType := f["error_type"]
	notificationType := f["notification_type"]
	resolved := f["resolved"]

	// Named notificationErrors rather than "errors" so the imported errors
	// package is not shadowed inside this function.
	notificationErrors, pagination, stats, err := h.workerService.GetNotificationErrors(ctx, page, pageSize, errorType, notificationType, resolved)
	if err != nil {
		h.logger.Error(ctx, "Failed to get notification errors", err, nil)
		c.JSON(http.StatusInternalServerError, gin.H{
			"error":   "Failed to get notification errors",
			"details": err.Error(),
		})
		return
	}
	WritePaginated(c, "errors", notificationErrors, pagination, gin.H{"stats": stats})
}
// GetSentNotifications responds with a page of sent-notification records.
//
// Pagination comes from ParsePagination(c, 1, 20, 100). The optional
// notification_type, status, sent_after and sent_before filters are passed
// to the worker service unchanged. The page is written by WritePaginated
// under "notifications" with the service's stats attached. A service failure
// is logged and returned as 500 with the error text under "details".
func (h *WorkerAdminHandler) GetSentNotifications(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "get_sent_notifications")
	defer span.End()

	page, pageSize := ParsePagination(c, 1, 20, 100)
	filters := ParseFilters(c, "notification_type", "status", "sent_after", "sent_before")

	notifications, pagination, stats, listErr := h.workerService.GetSentNotifications(
		ctx, page, pageSize,
		filters["notification_type"],
		filters["status"],
		filters["sent_after"],
		filters["sent_before"],
	)
	if listErr != nil {
		h.logger.Error(ctx, "Failed to get sent notifications", listErr, nil)
		c.JSON(http.StatusInternalServerError, gin.H{
			"error":   "Failed to get sent notifications",
			"details": listErr.Error(),
		})
		return
	}

	WritePaginated(c, "notifications", notifications, pagination, gin.H{"stats": stats})
}
// CreateTestSentNotification inserts a synthetic sent-notification record
// for testing.
//
// Request body fields user_id, notification_type, subject, template_name and
// status are required; error_message is optional. Invalid bodies are
// rejected with an invalid-input error. A service failure is logged and
// returned as 500 with the error text under "details".
func (h *WorkerAdminHandler) CreateTestSentNotification(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "create_test_sent_notification")
	defer span.End()

	var in struct {
		UserID           int    `json:"user_id" binding:"required"`
		NotificationType string `json:"notification_type" binding:"required"`
		Subject          string `json:"subject" binding:"required"`
		TemplateName     string `json:"template_name" binding:"required"`
		Status           string `json:"status" binding:"required"`
		ErrorMessage     string `json:"error_message"`
	}
	if bindErr := c.ShouldBindJSON(&in); bindErr != nil {
		invalid := contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request body",
			"",
			bindErr,
		)
		HandleAppError(c, invalid)
		return
	}

	createErr := h.workerService.CreateTestSentNotification(
		ctx,
		in.UserID,
		in.NotificationType,
		in.Subject,
		in.TemplateName,
		in.Status,
		in.ErrorMessage,
	)
	if createErr != nil {
		h.logger.Error(ctx, "Failed to create test sent notification", createErr, map[string]interface{}{
			"user_id":           in.UserID,
			"notification_type": in.NotificationType,
		})
		c.JSON(http.StatusInternalServerError, gin.H{
			"error":   "Failed to create test sent notification",
			"details": createErr.Error(),
		})
		return
	}

	c.JSON(http.StatusOK, gin.H{"message": "Test sent notification created successfully"})
}
// ForceSendNotification sends the daily reminder email to one user immediately,
// bypassing the normal time and date checks.
//
// Request body: {"username": "..."} (required). The user must exist, have a
// non-empty email address and have daily reminders enabled in their learning
// preferences. The email service is taken from the local worker instance, so
// the endpoint returns a service-unavailable error when no worker, or no email
// service, is attached. After a successful send, the notification is recorded
// and the user's last-reminder timestamp is updated. Failures in those two
// follow-up steps are logged but do not fail the request.
func (h *WorkerAdminHandler) ForceSendNotification(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "force_send_notification")
	defer span.End()

	var request struct {
		Username string `json:"username" binding:"required"`
	}
	if err := c.ShouldBindJSON(&request); err != nil {
		HandleAppError(c, contextutils.NewAppErrorWithCause(
			contextutils.ErrorCodeInvalidInput,
			contextutils.SeverityWarn,
			"Invalid request body",
			"",
			err,
		))
		return
	}

	user, err := h.userService.GetUserByUsername(ctx, request.Username)
	if err != nil {
		h.logger.Error(ctx, "Failed to get user by username", err, map[string]interface{}{
			"username": request.Username,
		})
		c.JSON(http.StatusInternalServerError, gin.H{
			"error":   "Failed to get user",
			"details": err.Error(),
		})
		return
	}
	if user == nil {
		HandleAppError(c, contextutils.NewAppError(
			contextutils.ErrorCodeRecordNotFound,
			contextutils.SeverityInfo,
			fmt.Sprintf("User '%s' not found", request.Username),
			"",
		))
		return
	}

	if !user.Email.Valid || user.Email.String == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}

	prefs, err := h.learningService.GetUserLearningPreferences(ctx, user.ID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get user learning preferences", err, map[string]interface{}{
			"user_id": user.ID,
		})
		c.JSON(http.StatusInternalServerError, gin.H{
			"error":   "Failed to get user preferences",
			"details": err.Error(),
		})
		return
	}
	if prefs == nil || !prefs.DailyReminderEnabled {
		HandleAppError(c, contextutils.NewAppError(contextutils.ErrorCodeInvalidInput, contextutils.SeverityWarn, "User has daily reminders disabled", ""))
		return
	}

	// subject/status are used for the notification record and the response.
	// The email itself is produced by SendDailyReminder, which does not take them.
	// NOTE(review): the subject's trailing character looks mis-encoded (likely
	// an emoji read as Latin-1) — confirm against the email template.
	subject := "Time for your daily quiz! ð"
	status := "sent"
	errorMsg := ""

	// h.worker may be nil (other handlers in this file guard against that);
	// calling GetEmailService on it unconditionally would panic.
	if h.worker == nil {
		HandleAppError(c, contextutils.ErrServiceUnavailable)
		return
	}
	emailService := h.worker.GetEmailService()
	if emailService == nil {
		HandleAppError(c, contextutils.ErrServiceUnavailable)
		return
	}

	if err := emailService.SendDailyReminder(ctx, user); err != nil {
		h.logger.Error(ctx, "Failed to send forced daily reminder", err, map[string]interface{}{
			"user_id": user.ID,
			"email":   user.Email.String,
		})
		HandleAppError(c, contextutils.WrapError(err, "failed to send notification"))
		return
	}

	if err := emailService.RecordSentNotification(ctx, user.ID, "daily_reminder", subject, "daily_reminder", status, errorMsg); err != nil {
		// Don't fail the request if recording fails; the email was already sent.
		h.logger.Error(ctx, "Failed to record sent notification", err, map[string]interface{}{
			"user_id": user.ID,
		})
	}

	if err := h.learningService.UpdateLastDailyReminderSent(ctx, user.ID); err != nil {
		// Don't fail the request if the timestamp update fails.
		h.logger.Error(ctx, "Failed to update last daily reminder sent timestamp", err, map[string]interface{}{
			"user_id": user.ID,
		})
	}

	h.logger.Info(ctx, "Forced notification sent successfully", map[string]interface{}{
		"user_id":  user.ID,
		"username": user.Username,
		"email":    user.Email.String,
	})

	c.JSON(http.StatusOK, gin.H{
		"message": "Notification sent successfully",
		"user": gin.H{
			"id":       user.ID,
			"username": user.Username,
			"email":    user.Email.String,
		},
		"notification": gin.H{
			"type":    "daily_reminder",
			"subject": subject,
			"status":  status,
		},
	})
}
// GetUserDailyQuestions returns the daily questions assigned to one user on one
// date (admin only).
//
// Path parameters:
//   - userId: an integer (otherwise invalid-format); the user must exist
//     (not-found otherwise; a lookup failure is logged and wrapped).
//   - date: required, in YYYY-MM-DD form (otherwise invalid-format).
//
// The response is {"questions": [...]}. Each element holds the assignment
// fields, per-user stats for the admin UI, and a nested "question" object. A
// service failure is logged and returned as 500 with the error text under
// "details". When no daily question service is configured, the handler
// responds with service-unavailable.
func (h *WorkerAdminHandler) GetUserDailyQuestions(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "admin_get_user_daily_questions")
	defer span.End()

	userIDStr := c.Param("userId")
	userID, err := strconv.Atoi(userIDStr)
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	user, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get user for daily questions", err, map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.WrapError(err, "failed to get user"))
		return
	}
	if user == nil {
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	dateStr := c.Param("date")
	if dateStr == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}
	date, err := time.Parse("2006-01-02", dateStr)
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("date", dateStr),
	)

	if h.dailyQuestionService == nil {
		HandleAppError(c, contextutils.ErrServiceUnavailable)
		return
	}

	questions, err := h.dailyQuestionService.GetDailyQuestions(ctx, userID, date)
	if err != nil {
		h.logger.Error(ctx, "Failed to get user daily questions", err, map[string]interface{}{
			"user_id": userID,
			"date":    dateStr,
		})
		c.JSON(http.StatusInternalServerError, gin.H{
			"error":   "Failed to get daily questions",
			"details": err.Error(),
		})
		return
	}

	// Convert to API format (similar to the daily question handler).
	apiQuestions := make([]gin.H, len(questions))
	for i, q := range questions {
		var completedAt *time.Time
		if q.CompletedAt.Valid {
			// Take the address of a per-iteration copy. Before Go 1.22, a
			// value-typed range variable q is a single variable reused by every
			// iteration, so &q.CompletedAt.Time would make every entry report
			// the last question's completion time once the JSON is encoded.
			t := q.CompletedAt.Time
			completedAt = &t
		}
		apiQuestions[i] = gin.H{
			"id":              q.ID,
			"user_id":         q.UserID,
			"question_id":     q.QuestionID,
			"assignment_date": q.AssignmentDate,
			"is_completed":    q.IsCompleted,
			"completed_at":    completedAt,
			"created_at":      q.CreatedAt,
			// Per-user stats for admin UI
			"user_shown_count":     q.DailyShownCount,
			"user_total_responses": q.UserTotalResponses,
			"user_correct_count":   q.UserCorrectCount,
			"user_incorrect_count": q.UserIncorrectCount,
			"question": gin.H{
				"id":                  q.Question.ID,
				"type":                q.Question.Type,
				"language":            q.Question.Language,
				"level":               q.Question.Level,
				"difficulty_score":    q.Question.DifficultyScore,
				"content":             q.Question.Content,
				"correct_answer":      q.Question.CorrectAnswer,
				"explanation":         q.Question.Explanation,
				"created_at":          q.Question.CreatedAt,
				"status":              q.Question.Status,
				"topic_category":      q.Question.TopicCategory,
				"grammar_focus":       q.Question.GrammarFocus,
				"vocabulary_domain":   q.Question.VocabularyDomain,
				"scenario":            q.Question.Scenario,
				"style_modifier":      q.Question.StyleModifier,
				"difficulty_modifier": q.Question.DifficultyModifier,
				"time_context":        q.Question.TimeContext,
			},
		}
	}
	c.JSON(http.StatusOK, gin.H{"questions": apiQuestions})
}
// RegenerateUserDailyQuestions clears and regenerates the daily questions for
// one user and date (admin only).
//
// Path parameters: userId (integer; the user must exist) and date
// (YYYY-MM-DD). The work is delegated to the daily question service's
// RegenerateDailyQuestions. A *services.NoQuestionsAvailableError from it
// yields 400 with the candidate details; any other failure yields 500. A
// missing daily question service yields service-unavailable.
func (h *WorkerAdminHandler) RegenerateUserDailyQuestions(c *gin.Context) {
	ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "admin_regenerate_user_daily_questions")
	defer span.End()

	userIDStr := c.Param("userId")
	userID, err := strconv.Atoi(userIDStr)
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	user, err := h.userService.GetUserByID(ctx, userID)
	if err != nil {
		h.logger.Error(ctx, "Failed to get user for daily questions regeneration", err, map[string]interface{}{"user_id": userID})
		HandleAppError(c, contextutils.WrapError(err, "failed to get user"))
		return
	}
	if user == nil {
		HandleAppError(c, contextutils.ErrRecordNotFound)
		return
	}

	dateStr := c.Param("date")
	if dateStr == "" {
		HandleAppError(c, contextutils.ErrMissingRequired)
		return
	}
	date, err := time.Parse("2006-01-02", dateStr)
	if err != nil {
		HandleAppError(c, contextutils.ErrInvalidFormat)
		return
	}

	span.SetAttributes(
		observability.AttributeUserID(userID),
		attribute.String("date", dateStr),
	)

	// Regeneration is performed by the daily question service, so that is the
	// dependency that must be present. Calling a method on a nil interface
	// would panic.
	if h.dailyQuestionService == nil {
		HandleAppError(c, contextutils.ErrServiceUnavailable)
		return
	}

	// RegenerateDailyQuestions clears the existing assignments and creates new ones.
	err = h.dailyQuestionService.RegenerateDailyQuestions(ctx, userID, date)
	if err != nil {
		h.logger.Error(ctx, "Failed to regenerate daily questions", err, map[string]interface{}{
			"user_id": userID,
			"date":    dateStr,
		})
		// No questions available for assignment: surface the service's structured details.
		var nqErr *services.NoQuestionsAvailableError
		if errors.As(err, &nqErr) {
			c.JSON(http.StatusBadRequest, gin.H{
				"error":                    "Failed to regenerate daily questions",
				"details":                  err.Error(),
				"user":                     gin.H{"id": user.ID, "username": user.Username, "language": nqErr.Language, "level": nqErr.Level},
				"candidate_count":          nqErr.CandidateCount,
				"candidate_ids":            nqErr.CandidateIDs,
				"total_matching_questions": nqErr.TotalMatching,
			})
			return
		}
		c.JSON(http.StatusInternalServerError, gin.H{
			"error":   "Failed to regenerate daily questions",
			"details": err.Error(),
		})
		return
	}

	h.logger.Info(ctx, "Daily questions regenerated successfully", map[string]interface{}{
		"user_id": userID,
		"date":    dateStr,
	})
	c.JSON(http.StatusOK, gin.H{"success": true, "message": "Daily questions regenerated successfully. All existing assignments have been cleared and new questions assigned."})
}
// Package middleware provides authentication and authorization middleware for the Gin web framework.
package middleware
import (
"context"
"net/http"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
)
// UserIDKey is the key under which the authenticated user's ID is stored, both
// in the session and (by the auth middleware) in the gin context.
const UserIDKey = "user_id"

// UsernameKey is the key under which the authenticated user's username is
// stored, both in the session and (by the auth middleware) in the gin context.
const UsernameKey = "username"
// RequireAuth returns a middleware that requires an authenticated session.
//
// The request is aborted with 401 {"error": "Authentication required",
// "code": "UNAUTHORIZED"} unless the session holds an integral user ID and a
// non-empty username. On success both values are copied into the gin context
// under UserIDKey and UsernameKey for downstream handlers.
func RequireAuth() gin.HandlerFunc {
	return func(c *gin.Context) {
		userID, username, ok := sessionIdentity(c)
		if !ok {
			abortAuthRequired(c)
			return
		}
		c.Set(UserIDKey, userID)
		c.Set(UsernameKey, username)
		c.Next()
	}
}

// RequireAdmin returns a middleware that requires an authenticated session whose
// user is an admin.
//
// userService must provide IsAdmin(ctx, userID) (bool, error); otherwise this
// function panics when the middleware is constructed. Per request:
//   - 401 UNAUTHORIZED when the session identity is missing or invalid,
//   - 500 INTERNAL_ERROR when IsAdmin returns an error,
//   - 403 FORBIDDEN when the user is not an admin.
//
// On success the user ID and username are stored in the gin context as in
// RequireAuth.
func RequireAdmin(userService interface{}) gin.HandlerFunc {
	us, ok := userService.(interface {
		IsAdmin(ctx context.Context, userID int) (bool, error)
	})
	if !ok {
		panic("userService must implement IsAdmin method")
	}

	return func(c *gin.Context) {
		userID, username, ok := sessionIdentity(c)
		if !ok {
			abortAuthRequired(c)
			return
		}

		isAdmin, err := us.IsAdmin(c.Request.Context(), userID)
		if err != nil {
			c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{
				"error": "Failed to check admin status",
				"code":  "INTERNAL_ERROR",
			})
			return
		}
		if !isAdmin {
			c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
				"error": "Admin access required",
				"code":  "FORBIDDEN",
			})
			return
		}

		c.Set(UserIDKey, userID)
		c.Set(UsernameKey, username)
		c.Next()
	}
}

// sessionIdentity reads and validates the user ID and username stored in the
// request's session; ok is false if either is missing or malformed.
func sessionIdentity(c *gin.Context) (userID int, username string, ok bool) {
	session := sessions.Default(c)

	userID, ok = sessionUserIDFromValue(session.Get(UserIDKey))
	if !ok {
		return 0, "", false
	}

	username, ok = session.Get(UsernameKey).(string)
	if !ok || username == "" {
		return 0, "", false
	}
	return userID, username, true
}

// sessionUserIDFromValue converts a raw session value into an int user ID.
func sessionUserIDFromValue(v interface{}) (int, bool) {
	switch id := v.(type) {
	case int:
		return id, true
	case float64:
		// Numbers that round-trip through JSON come back as float64. Accept only
		// integral values; the old code silently truncated e.g. 3.7 to user 3.
		n := int(id)
		if float64(n) != id {
			return 0, false
		}
		return n, true
	default:
		// Covers nil (no session value) and any other stored type.
		return 0, false
	}
}

// abortAuthRequired stops the handler chain with the standard 401 payload.
func abortAuthRequired(c *gin.Context) {
	c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
		"error": "Authentication required",
		"code":  "UNAUTHORIZED",
	})
}
package middleware
import (
	"errors"
	"fmt"
	"net/http"
	"runtime/debug"
	"sync"
	"time"

	contextutils "quizapp/internal/utils"

	"github.com/gin-gonic/gin"
)
// ErrorRecoveryConfig configures the behavior of ErrorRecoveryMiddleware.
type ErrorRecoveryConfig struct {
	// MaxRetries specifies the maximum number of retries for retryable errors.
	MaxRetries int
	// RetryDelay specifies the base delay between retries (doubled after each attempt).
	RetryDelay time.Duration
	// MaxRetryDelay caps the doubled delay between retries.
	MaxRetryDelay time.Duration
	// EnableCircuitBreaker enables the circuit breaker.
	EnableCircuitBreaker bool
	// CircuitBreakerThreshold is the failure count at which the circuit opens.
	CircuitBreakerThreshold int
	// CircuitBreakerTimeout is how long the circuit stays open after the last
	// failure before a request is let through again (half-open).
	CircuitBreakerTimeout time.Duration
}

// DefaultErrorRecoveryConfig returns a default error recovery configuration:
// 3 retries starting at 100ms and capped at 5s. The circuit breaker is disabled,
// but is preconfigured to open after 5 failures and half-open after 30s.
func DefaultErrorRecoveryConfig() *ErrorRecoveryConfig {
	return &ErrorRecoveryConfig{
		MaxRetries:              3,
		RetryDelay:              100 * time.Millisecond,
		MaxRetryDelay:           5 * time.Second,
		EnableCircuitBreaker:    false,
		CircuitBreakerThreshold: 5,
		CircuitBreakerTimeout:   30 * time.Second,
	}
}

// circuitBreakerState represents the state of a circuit breaker.
type circuitBreakerState int

const (
	circuitClosed circuitBreakerState = iota
	circuitOpen
	circuitHalfOpen
)

// circuitBreaker counts failures and gates execution. State transitions:
//   - closed -> open once failures reach CircuitBreakerThreshold,
//   - open -> half-open when CircuitBreakerTimeout has passed since the last failure,
//   - any state -> closed on recordSuccess (which also zeroes the failure count).
//
// A failure while half-open reopens the circuit immediately, because the failure
// count is not reset on the transition to half-open.
//
// ErrorRecoveryMiddleware shares one instance across every request it handles,
// and gin serves requests on concurrent goroutines. All mutable fields are
// therefore guarded by mu (the original type had no locking, which is a data race).
type circuitBreaker struct {
	mu          sync.Mutex
	state       circuitBreakerState
	failures    int
	lastFailure time.Time
	config      *ErrorRecoveryConfig // read-only after construction
}

// newCircuitBreaker creates a closed circuit breaker using config's thresholds.
func newCircuitBreaker(config *ErrorRecoveryConfig) *circuitBreaker {
	return &circuitBreaker{
		state:  circuitClosed,
		config: config,
	}
}

// canExecute reports whether a request may proceed. An open circuit whose
// timeout has elapsed transitions to half-open and allows the request.
func (cb *circuitBreaker) canExecute() bool {
	cb.mu.Lock()
	defer cb.mu.Unlock()

	switch cb.state {
	case circuitClosed, circuitHalfOpen:
		return true
	case circuitOpen:
		if time.Since(cb.lastFailure) > cb.config.CircuitBreakerTimeout {
			cb.state = circuitHalfOpen
			return true
		}
		return false
	default:
		return false
	}
}

// recordSuccess clears the failure count and closes the circuit.
func (cb *circuitBreaker) recordSuccess() {
	cb.mu.Lock()
	defer cb.mu.Unlock()

	cb.failures = 0
	cb.state = circuitClosed
}

// recordFailure counts a failure and opens the circuit at the threshold.
func (cb *circuitBreaker) recordFailure() {
	cb.mu.Lock()
	defer cb.mu.Unlock()

	cb.failures++
	cb.lastFailure = time.Now()
	if cb.failures >= cb.config.CircuitBreakerThreshold {
		cb.state = circuitOpen
	}
}
// ErrorRecoveryMiddleware creates middleware that recovers from handler panics,
// optionally gates requests through a circuit breaker, and calls
// retryWithBackoff when shouldRetry classifies the outcome as retryable.
//
// A nil config falls back to DefaultErrorRecoveryConfig. logger is only
// nil-checked and passed through to retryWithBackoff; panic details are printed
// to stdout.
//
// Circuit breaker accounting: a 5xx response or a panic counts as a failure; any
// other response counts as a success, which resets the failure count. Only
// consecutive failures can therefore trip the breaker; previously successes in
// the closed state were ignored, so failures accumulated without bound.
func ErrorRecoveryMiddleware(logger interface{}, config *ErrorRecoveryConfig) gin.HandlerFunc {
	if config == nil {
		config = DefaultErrorRecoveryConfig()
	}

	// One breaker is shared by all requests handled by this middleware instance.
	var cb *circuitBreaker
	if config.EnableCircuitBreaker {
		cb = newCircuitBreaker(config)
	}

	return func(c *gin.Context) {
		defer func() {
			recovered := recover()
			if recovered == nil {
				return
			}

			stackTrace := string(debug.Stack())
			fmt.Printf("Panic recovered: %v\nStack trace: %s\n", recovered, stackTrace)

			// The bookkeeping after c.Next() never runs after a panic, so count it here.
			if cb != nil {
				cb.recordFailure()
			}

			panicErr, ok := recovered.(error)
			if !ok {
				// Build a fresh error. Wrapping a nil error (the previous
				// WrapErrorf(nil, ...)) may drop the panic value entirely.
				panicErr = fmt.Errorf("panic: %v", recovered)
			}

			appErr := contextutils.NewAppErrorWithCause(
				contextutils.ErrorCodeInternalError,
				contextutils.SeverityFatal,
				"Internal server error",
				"A panic occurred while processing the request",
				contextutils.WrapError(panicErr, "panic"),
			)
			// Expose the stack trace only in gin debug mode.
			if gin.Mode() == gin.DebugMode {
				appErr.Details = fmt.Sprintf("%s\nStack trace: %s", appErr.Details, stackTrace)
			}

			// If the handler already started the response, writing an error body
			// would append it to the partial output; only stop the chain then.
			if !c.Writer.Written() {
				HandleAppError(c, appErr)
			}
			c.Abort()
		}()

		if cb != nil && !cb.canExecute() {
			ServiceUnavailable(c, "Service temporarily unavailable due to high error rate")
			c.Abort()
			return
		}

		c.Next()

		status := c.Writer.Status()
		if cb != nil {
			// Use the breaker's methods only; reading cb.state directly here
			// raced with other requests.
			if status >= 500 {
				cb.recordFailure()
			} else {
				cb.recordSuccess()
			}
		}

		if shouldRetry(status, c.Errors) {
			retryWithBackoff(c, config, logger)
		}
	}
}
// shouldRetry reports whether a finished request looks transient enough to retry.
// That is the case for any 5xx status, 408 Request Timeout, 429 Too Many
// Requests, or an attached gin error that contextutils.IsRetryable accepts.
func shouldRetry(statusCode int, ginErrors []*gin.Error) bool {
	switch {
	case statusCode >= 500,
		statusCode == http.StatusRequestTimeout,
		statusCode == http.StatusTooManyRequests:
		return true
	}

	for i := range ginErrors {
		if contextutils.IsRetryable(ginErrors[i]) {
			return true
		}
	}
	return false
}
// retryWithBackoff is the retry hook ErrorRecoveryMiddleware calls after a
// retryable failure. It applies only to idempotent methods (GET, HEAD, OPTIONS,
// PUT, DELETE) and waits with exponential backoff (RetryDelay doubled each time,
// capped at MaxRetryDelay) for up to MaxRetries attempts, logging each attempt
// when logger is non-nil.
//
// NOTE: this hook does not re-execute the request. Because of that it returns
// immediately once the response has been committed, which is the normal case
// after a handler writes its error body. Sleeping at that point only delayed
// delivery of a response that can no longer change. Waits also end early if the
// client's request context is cancelled.
func retryWithBackoff(c *gin.Context, config *ErrorRecoveryConfig, logger interface{}) {
	method := c.Request.Method
	switch method {
	case http.MethodGet, http.MethodHead, http.MethodOptions, http.MethodPut, http.MethodDelete:
	default:
		return
	}

	if c.HandlerName() == "" {
		return
	}

	// Bytes have already been sent; nothing here could change the outcome.
	if c.Writer.Written() {
		return
	}

	ctx := c.Request.Context()
	delay := config.RetryDelay
	for attempt := 1; attempt <= config.MaxRetries; attempt++ {
		timer := time.NewTimer(delay)
		select {
		case <-ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}

		delay *= 2
		if delay > config.MaxRetryDelay {
			delay = config.MaxRetryDelay
		}

		if logger != nil {
			// This would be logged using the observability logger in a real implementation.
			fmt.Printf("Retrying request %s %s (attempt %d/%d)\n",
				method, c.Request.URL.Path, attempt, config.MaxRetries)
		}
		// Re-executing the request would require recreating it; intentionally not done here.
	}
}
// HandleAppError writes a JSON error response for err.
//
// If err is a *contextutils.AppError, or wraps one anywhere in its chain, that
// AppError is rendered via StandardizeAppError, so its code decides the HTTP
// status. Previously a direct type assertion missed wrapped AppErrors and turned
// them into generic 500s. Any other error produces a 500 whose details are
// err.Error(). A nil err is rendered as a 500 with empty details instead of
// panicking.
func HandleAppError(c *gin.Context, err error) {
	if err == nil {
		StandardizeHTTPError(c, http.StatusInternalServerError, "Internal server error", "")
		return
	}

	var appErr *contextutils.AppError
	if errors.As(err, &appErr) {
		StandardizeAppError(c, appErr)
		return
	}

	// Fallback for non-AppError types.
	StandardizeHTTPError(c, http.StatusInternalServerError, "Internal server error", err.Error())
}
// StandardizeAppError writes err as a JSON error response. The body is
// err.ToJSON() plus a "retryable" flag from contextutils.IsRetryable, and the
// status comes from mapErrorCodeToHTTPStatus(err.Code).
func StandardizeAppError(c *gin.Context, err *contextutils.AppError) {
	payload := err.ToJSON()
	payload["retryable"] = contextutils.IsRetryable(err)
	c.JSON(mapErrorCodeToHTTPStatus(err.Code), payload)
}
// StandardizeHTTPError writes a JSON error response with the requested HTTP
// status. The body has the same shape StandardizeAppError produces
// (AppError.ToJSON plus a "retryable" flag), with message and details filled in.
//
// statusCode used to be ignored (the parameter was named `_`), so every caller
// got a 500/INTERNAL_ERROR whatever it asked for. It is now honored for any
// 4xx/5xx value, and values outside that range fall back to 500. The body's
// error code is chosen to correspond to the status; a 500 request still yields
// exactly the previous response.
func StandardizeHTTPError(c *gin.Context, statusCode int, message, details string) {
	if statusCode < http.StatusBadRequest || statusCode > 599 {
		statusCode = http.StatusInternalServerError
	}

	appErr := contextutils.NewAppError(
		errorCodeForHTTPStatus(statusCode),
		contextutils.SeverityError,
		message,
		details,
	)

	body := appErr.ToJSON()
	body["retryable"] = contextutils.IsRetryable(appErr)
	c.JSON(statusCode, body)
}

// errorCodeForHTTPStatus picks the AppError code that corresponds to an HTTP
// error status. Unlisted 4xx statuses map to invalid input, unlisted 5xx to an
// internal error.
func errorCodeForHTTPStatus(statusCode int) contextutils.ErrorCode {
	switch statusCode {
	case http.StatusBadRequest:
		return contextutils.ErrorCodeInvalidInput
	case http.StatusUnauthorized:
		return contextutils.ErrorCodeUnauthorized
	case http.StatusForbidden:
		return contextutils.ErrorCodeForbidden
	case http.StatusNotFound:
		return contextutils.ErrorCodeRecordNotFound
	case http.StatusRequestTimeout:
		return contextutils.ErrorCodeTimeout
	case http.StatusConflict:
		return contextutils.ErrorCodeRecordExists
	case http.StatusTooManyRequests:
		return contextutils.ErrorCodeRateLimit
	case http.StatusServiceUnavailable:
		return contextutils.ErrorCodeServiceUnavailable
	}
	if statusCode < http.StatusInternalServerError {
		return contextutils.ErrorCodeInvalidInput
	}
	return contextutils.ErrorCodeInternalError
}
// ServiceUnavailable writes a service-unavailable error response with msg as the
// message and no details. ErrorCodeServiceUnavailable maps to HTTP 503 through
// mapErrorCodeToHTTPStatus.
func ServiceUnavailable(c *gin.Context, msg string) {
	StandardizeAppError(c, contextutils.NewAppError(
		contextutils.ErrorCodeServiceUnavailable,
		contextutils.SeverityError,
		msg,
		"",
	))
}
// mapErrorCodeToHTTPStatus translates an AppError code into the HTTP status used
// for the response. Codes not listed explicitly fall back to 500.
func mapErrorCodeToHTTPStatus(code contextutils.ErrorCode) int {
	switch code {
	// Malformed or invalid client input.
	case contextutils.ErrorCodeInvalidInput,
		contextutils.ErrorCodeMissingRequired,
		contextutils.ErrorCodeInvalidFormat,
		contextutils.ErrorCodeValidationFailed,
		contextutils.ErrorCodeOAuthStateMismatch:
		return http.StatusBadRequest

	// Missing, expired, or wrong credentials.
	case contextutils.ErrorCodeUnauthorized,
		contextutils.ErrorCodeSessionExpired,
		contextutils.ErrorCodeInvalidCredentials:
		return http.StatusUnauthorized

	case contextutils.ErrorCodeForbidden:
		return http.StatusForbidden

	case contextutils.ErrorCodeRecordNotFound,
		contextutils.ErrorCodeQuestionNotFound,
		contextutils.ErrorCodeAssignmentNotFound:
		return http.StatusNotFound

	case contextutils.ErrorCodeRecordExists:
		return http.StatusConflict

	case contextutils.ErrorCodeRateLimit:
		return http.StatusTooManyRequests

	case contextutils.ErrorCodeTimeout:
		// NOTE(review): a server-side timeout is conventionally 504 Gateway
		// Timeout; 408 tells the client that *it* was too slow sending the
		// request. Kept as-is for compatibility — confirm intent.
		return http.StatusRequestTimeout

	// Dependencies that are temporarily unreachable.
	case contextutils.ErrorCodeServiceUnavailable,
		contextutils.ErrorCodeDatabaseConnection,
		contextutils.ErrorCodeAIProviderUnavailable:
		return http.StatusServiceUnavailable

	// Server-side failures.
	case contextutils.ErrorCodeInternalError,
		contextutils.ErrorCodeDatabaseQuery,
		contextutils.ErrorCodeDatabaseTransaction,
		contextutils.ErrorCodeForeignKeyViolation,
		contextutils.ErrorCodeTimestampMissingTimezone,
		contextutils.ErrorCodeAIRequestFailed,
		contextutils.ErrorCodeAIResponseInvalid,
		contextutils.ErrorCodeAIConfigInvalid,
		contextutils.ErrorCodeOAuthProviderError:
		return http.StatusInternalServerError
	}
	// Unknown codes are treated as internal errors.
	return http.StatusInternalServerError
}
package middleware
import (
"encoding/json"
"fmt"
"os"
"strings"
contextutils "quizapp/internal/utils"
"github.com/xeipuuv/gojsonschema"
"gopkg.in/yaml.v2"
)
// SchemaLoader holds compiled JSON schemas, keyed by their name under
// components/schemas in a Swagger/OpenAPI document.
type SchemaLoader struct {
	schemas map[string]*gojsonschema.Schema
}

// NewSchemaLoader returns a SchemaLoader with an empty, ready-to-fill schema map.
func NewSchemaLoader() *SchemaLoader {
	loader := new(SchemaLoader)
	loader.schemas = map[string]*gojsonschema.Schema{}
	return loader
}
// LoadSchemasFromSwagger reads the YAML Swagger/OpenAPI file at swaggerPath and
// compiles every entry under components/schemas into sl.schemas.
//
// Each schema is compiled against a document that embeds all converted schemas,
// so "#/components/schemas/..." references between them resolve. Schemas whose
// name is not a string, or that fail conversion, marshalling, or compilation, are
// reported on stdout and skipped. An error is returned only when the file cannot
// be read or parsed, or has no components/schemas section.
func (sl *SchemaLoader) LoadSchemasFromSwagger(swaggerPath string) error {
	raw, err := os.ReadFile(swaggerPath)
	if err != nil {
		return contextutils.WrapError(err, "failed to read swagger file")
	}

	var spec map[string]interface{}
	if err = yaml.Unmarshal(raw, &spec); err != nil {
		return contextutils.WrapError(err, "failed to parse swagger file as YAML")
	}
	fmt.Printf("â Successfully parsed swagger file as YAML\n")

	rawComponents := spec["components"]
	components, isMap := rawComponents.(map[interface{}]interface{})
	if !isMap {
		fmt.Printf("â No components section found. Available keys: %v\n", getKeys(spec))
		fmt.Printf("â Components type: %T, value: %v\n", rawComponents, rawComponents)
		return contextutils.ErrorWithContextf("no components section found in swagger")
	}

	rawSchemas := components["schemas"]
	yamlSchemas, isMap := rawSchemas.(map[interface{}]interface{})
	if !isMap {
		fmt.Printf("â No schemas section found in components. Available keys: %v\n", getKeysInterface(components))
		fmt.Printf("â Schemas type: %T, value: %v\n", rawSchemas, rawSchemas)
		return contextutils.ErrorWithContextf("no schemas section found in swagger")
	}

	// Normalize YAML's interface{}-keyed maps into a JSON-encodable form.
	converted := make(map[string]interface{})
	for key, value := range yamlSchemas {
		name, isString := key.(string)
		if !isString {
			fmt.Printf("Warning: schema name is not a string: %v\n", key)
			continue
		}
		jsonValue, convErr := convertToJSONCompatible(value)
		if convErr != nil {
			fmt.Printf("Warning: failed to convert schema %s: %v\n", name, convErr)
			continue
		}
		converted[name] = jsonValue
	}

	// Compile each schema as a $ref into the shared components document.
	for name := range converted {
		doc := map[string]interface{}{
			"$schema":    "http://json-schema.org/draft-07/schema#",
			"components": map[string]interface{}{"schemas": converted},
			"$ref":       "#/components/schemas/" + name,
		}

		docBytes, marshalErr := json.Marshal(doc)
		if marshalErr != nil {
			fmt.Printf("Warning: failed to marshal schema %s: %v\n", name, marshalErr)
			continue
		}

		compiled, compileErr := gojsonschema.NewSchema(gojsonschema.NewBytesLoader(docBytes))
		if compileErr != nil {
			fmt.Printf("Warning: failed to load schema %s: %v\n", name, compileErr)
			continue
		}

		sl.schemas[name] = compiled
		fmt.Printf("â Loaded schema: %s\n", name)
	}
	return nil
}
// getKeys returns the keys of m in unspecified (map iteration) order. The result
// is non-nil even for an empty or nil map.
func getKeys(m map[string]interface{}) []string {
	out := make([]string, len(m))
	i := 0
	for key := range m {
		out[i] = key
		i++
	}
	return out
}
// getKeysInterface returns the string-typed keys of m in unspecified order;
// keys of any other type are skipped.
func getKeysInterface(m map[interface{}]interface{}) []string {
	names := make([]string, 0, len(m))
	for rawKey := range m {
		name, isString := rawKey.(string)
		if !isString {
			continue
		}
		names = append(names, name)
	}
	return names
}
// convertInterfaceMapToStringMap converts a YAML-decoded map[interface{}]interface{}
// into a map[string]interface{}. Values are copied as-is, not converted recursively.
//
// yaml.v2 decodes unquoted scalar keys by their type, so a responses section
// written as `200:` yields an int key. Non-string keys are now rendered with
// fmt.Sprint instead of being dropped, so lookups such as responses["200"] find
// them. If a string key and a non-string key render to the same text, the string
// key wins.
func convertInterfaceMapToStringMap(m map[interface{}]interface{}) map[string]interface{} {
	result := make(map[string]interface{}, len(m))

	// Pass 1: genuine string keys take precedence.
	for k, v := range m {
		if keyStr, ok := k.(string); ok {
			result[keyStr] = v
		}
	}

	// Pass 2: stringify other scalar keys without overwriting pass-1 entries.
	for k, v := range m {
		if _, ok := k.(string); ok {
			continue
		}
		keyStr := fmt.Sprint(k)
		if _, taken := result[keyStr]; !taken {
			result[keyStr] = v
		}
	}
	return result
}
// convertToJSONCompatible recursively converts YAML-decoded data into values that
// encoding/json can marshal. Every map[interface{}]interface{} becomes a
// map[string]interface{}, slices are converted element-wise, and other values
// pass through unchanged. A non-string map key is an error.
//
// OpenAPI's `nullable: true` has no JSON Schema equivalent, so it is removed and
// expressed as a null-accepting schema:
//   - with a $ref: {"anyOf": [{"$ref": ...}, {"enum": [null]}]}. This used to be
//     oneOf, which rejects null whenever the referenced schema itself accepts null,
//     because null then matches both branches.
//   - with a string "type": the type becomes [type, "null"].
//   - otherwise (e.g. an allOf composition): the whole schema is wrapped in
//     anyOf with a null branch. Previously the nullable flag was silently lost here.
func convertToJSONCompatible(data interface{}) (interface{}, error) {
	switch v := data.(type) {
	case map[interface{}]interface{}:
		result := make(map[string]interface{}, len(v))
		nullable := false
		for k, val := range v {
			keyStr, ok := k.(string)
			if !ok {
				return nil, contextutils.ErrorWithContextf("key is not a string: %v", k)
			}
			// Consume `nullable: true`; it is re-expressed below. Any other value
			// for `nullable` is kept as-is.
			if keyStr == "nullable" {
				if flag, isBool := val.(bool); isBool && flag {
					nullable = true
					continue
				}
			}
			convertedVal, err := convertToJSONCompatible(val)
			if err != nil {
				return nil, err
			}
			result[keyStr] = convertedVal
		}

		if !nullable {
			return result, nil
		}

		nullSchema := map[string]interface{}{"enum": []interface{}{nil}}
		if ref, hasRef := result["$ref"].(string); hasRef {
			delete(result, "$ref")
			result["anyOf"] = []interface{}{
				map[string]interface{}{"$ref": ref},
				nullSchema,
			}
			return result, nil
		}
		if typeVal, hasType := result["type"].(string); hasType {
			result["type"] = []interface{}{typeVal, "null"}
			return result, nil
		}
		return map[string]interface{}{
			"anyOf": []interface{}{result, nullSchema},
		}, nil

	case []interface{}:
		result := make([]interface{}, len(v))
		for i, val := range v {
			convertedVal, err := convertToJSONCompatible(val)
			if err != nil {
				return nil, err
			}
			result[i] = convertedVal
		}
		return result, nil

	default:
		return data, nil
	}
}
// ValidateData marshals data to JSON and validates it against the compiled schema
// registered under schemaName.
//
// It returns an error if the schema is unknown, the data cannot be marshalled,
// or the validator fails to run. It also returns an error when the document does
// not conform; that error lists every violation as "field: description", joined
// with "; ".
func (sl *SchemaLoader) ValidateData(data interface{}, schemaName string) error {
	schema, found := sl.schemas[schemaName]
	if !found {
		return contextutils.ErrorWithContextf("schema %s not found", schemaName)
	}

	payload, err := json.Marshal(data)
	if err != nil {
		return contextutils.WrapError(err, "failed to marshal data")
	}

	outcome, err := schema.Validate(gojsonschema.NewBytesLoader(payload))
	if err != nil {
		return contextutils.WrapError(err, "validation error")
	}
	if outcome.Valid() {
		return nil
	}

	problems := outcome.Errors()
	messages := make([]string, len(problems))
	for i, problem := range problems {
		messages[i] = fmt.Sprintf("%s: %s", problem.Field(), problem.Description())
	}
	return contextutils.ErrorWithContextf("schema validation failed: %s", strings.Join(messages, "; "))
}
// AutoLoadSchemas builds a SchemaLoader from the file named by the
// SWAGGER_FILE_PATH environment variable. It never fails: if the variable is
// unset, the file is missing, or loading errors, a message is printed and the
// (possibly empty) loader is returned anyway.
func AutoLoadSchemas() *SchemaLoader {
	loader := NewSchemaLoader()

	swaggerPath := os.Getenv("SWAGGER_FILE_PATH")
	if swaggerPath == "" {
		fmt.Printf("â SWAGGER_FILE_PATH environment variable not set\n")
		return loader
	}

	if _, statErr := os.Stat(swaggerPath); statErr != nil {
		fmt.Printf("âï Swagger file not found at %s: %v\n", swaggerPath, statErr)
		return loader
	}

	if err := loader.LoadSchemasFromSwagger(swaggerPath); err != nil {
		fmt.Printf("Warning: failed to load schemas from %s: %v\n", swaggerPath, err)
		return loader
	}

	fmt.Printf("â Successfully loaded schemas from %s\n", swaggerPath)
	return loader
}
// IsEndpointDocumented reports whether the Swagger file named by
// SWAGGER_FILE_PATH declares an operation for method on path.
//
// path is first compared to the spec's path keys exactly, then against templated
// keys such as /users/{id} via pathMatchesPattern. Any problem locating, reading,
// or parsing the file yields false.
//
// NOTE(review): the file is re-read and re-parsed on every call.
func (sl *SchemaLoader) IsEndpointDocumented(path, method string) bool {
	specPath := os.Getenv("SWAGGER_FILE_PATH")
	if specPath == "" {
		return false
	}
	if _, err := os.Stat(specPath); err != nil {
		return false
	}
	raw, err := os.ReadFile(specPath)
	if err != nil {
		return false
	}
	var spec map[string]interface{}
	if err := yaml.Unmarshal(raw, &spec); err != nil {
		return false
	}

	// asStringMap accepts either map flavour a YAML decoder can produce.
	asStringMap := func(v interface{}) (map[string]interface{}, bool) {
		switch m := v.(type) {
		case map[string]interface{}:
			return m, true
		case map[interface{}]interface{}:
			return convertInterfaceMapToStringMap(m), true
		default:
			return nil, false
		}
	}

	paths, ok := asStringMap(spec["paths"])
	if !ok {
		return false
	}
	verb := strings.ToLower(method)

	// Exact path key first; a malformed entry there ends the search.
	if item, found := paths[path]; found {
		operations, ok := asStringMap(item)
		if !ok {
			return false
		}
		if _, has := operations[verb]; has {
			return true
		}
	}

	// Then any templated path key that matches.
	for template, item := range paths {
		if !sl.pathMatchesPattern(path, template) {
			continue
		}
		operations, ok := asStringMap(item)
		if !ok {
			continue
		}
		if _, has := operations[verb]; has {
			return true
		}
	}
	return false
}
// pathMatchesPattern reports whether requestPath matches the Swagger path
// template swaggerPath.
//
// Both are split on "/" and must have the same number of segments. A template
// segment of the form {name} matches any single non-empty request segment;
// every other segment must match exactly.
//
// An empty request segment (e.g. the trailing slash in "/users/") no longer
// satisfies a {param} segment. OpenAPI path parameters are always required, so
// "/users/" must not be treated as an instance of "/users/{id}".
func (sl *SchemaLoader) pathMatchesPattern(requestPath, swaggerPath string) bool {
	requestSegments := strings.Split(requestPath, "/")
	swaggerSegments := strings.Split(swaggerPath, "/")

	if len(requestSegments) != len(swaggerSegments) {
		return false
	}

	for i, swaggerSegment := range swaggerSegments {
		requestSegment := requestSegments[i]

		isParam := strings.HasPrefix(swaggerSegment, "{") && strings.HasSuffix(swaggerSegment, "}")
		if isParam {
			if requestSegment == "" {
				return false
			}
			continue
		}

		if swaggerSegment != requestSegment {
			return false
		}
	}
	return true
}
// DetermineRequestSchemaFromPath returns the name of the component schema used
// by the JSON request body of method on path, as declared in the Swagger file
// named by SWAGGER_FILE_PATH.
//
// It follows
// paths[path][method].requestBody.content["application/json"].schema.$ref and
// returns the final "/"-separated element of the $ref. An empty string means a
// step is missing or malformed, or the $ref has fewer than four parts. Problems
// locating, reading, or parsing the file are also printed as DEBUG lines.
//
// NOTE(review): path must equal a spec path key exactly (templated keys such as
// /users/{id} are not matched), and the file is re-read on every call.
func (sl *SchemaLoader) DetermineRequestSchemaFromPath(path, method string) string {
	specPath := os.Getenv("SWAGGER_FILE_PATH")
	if specPath == "" {
		fmt.Printf("DEBUG: SWAGGER_FILE_PATH not set\n")
		return ""
	}
	if _, err := os.Stat(specPath); err != nil {
		fmt.Printf("DEBUG: Swagger file not found: %s\n", specPath)
		return ""
	}
	raw, err := os.ReadFile(specPath)
	if err != nil {
		fmt.Printf("DEBUG: Failed to read swagger file: %v\n", err)
		return ""
	}
	var spec map[string]interface{}
	if err := yaml.Unmarshal(raw, &spec); err != nil {
		fmt.Printf("DEBUG: Failed to parse swagger file: %v\n", err)
		return ""
	}

	// asStringMap accepts either map flavour a YAML decoder can produce.
	asStringMap := func(v interface{}) (map[string]interface{}, bool) {
		switch m := v.(type) {
		case map[string]interface{}:
			return m, true
		case map[interface{}]interface{}:
			return convertInterfaceMapToStringMap(m), true
		default:
			return nil, false
		}
	}

	// Walk paths -> path -> method -> requestBody -> content -> application/json -> schema.
	steps := []string{path, strings.ToLower(method), "requestBody", "content", "application/json", "schema"}
	node, ok := asStringMap(spec["paths"])
	for _, key := range steps {
		if !ok {
			return ""
		}
		node, ok = asStringMap(node[key])
	}
	if !ok {
		return ""
	}

	refStr, ok := node["$ref"].(string)
	if !ok {
		return ""
	}

	// Expected shape: "#/components/schemas/SchemaName".
	parts := strings.Split(refStr, "/")
	if len(parts) < 4 {
		return ""
	}
	return parts[len(parts)-1]
}
// DetermineSchemaFromPath returns the name of the component schema used by the
// 200 response of method on path, as declared in the Swagger file named by
// SWAGGER_FILE_PATH.
//
// It follows
// paths[path][method].responses["200"].content["application/json"].schema.$ref
// and returns the part after "#/components/schemas/". An empty string means the
// file or any step of that chain is missing or malformed, or the $ref points
// somewhere else.
//
// NOTE(review): path must equal a spec path key exactly (templated keys such as
// /users/{id} are not matched), and the file is re-read on every call.
func (sl *SchemaLoader) DetermineSchemaFromPath(path, method string) string {
	specPath := os.Getenv("SWAGGER_FILE_PATH")
	if specPath == "" {
		return ""
	}
	if _, err := os.Stat(specPath); err != nil {
		return ""
	}
	raw, err := os.ReadFile(specPath)
	if err != nil {
		return ""
	}
	var spec map[string]interface{}
	if err := yaml.Unmarshal(raw, &spec); err != nil {
		return ""
	}

	// asStringMap accepts either map flavour a YAML decoder can produce.
	asStringMap := func(v interface{}) (map[string]interface{}, bool) {
		switch m := v.(type) {
		case map[string]interface{}:
			return m, true
		case map[interface{}]interface{}:
			return convertInterfaceMapToStringMap(m), true
		default:
			return nil, false
		}
	}

	// Walk paths -> path -> method -> responses -> 200 -> content -> application/json -> schema.
	steps := []string{path, strings.ToLower(method), "responses", "200", "content", "application/json", "schema"}
	node, ok := asStringMap(spec["paths"])
	for _, key := range steps {
		if !ok {
			return ""
		}
		node, ok = asStringMap(node[key])
	}
	if !ok {
		return ""
	}

	ref, _ := node["$ref"].(string)
	const componentPrefix = "#/components/schemas/"
	if !strings.HasPrefix(ref, componentPrefix) {
		return ""
	}
	return ref[len(componentPrefix):]
}
package middleware
import (
"bytes"
"encoding/json"
"fmt"
"io"
"math"
"net/http"
"strings"
"quizapp/internal/observability"
"github.com/gin-gonic/gin"
)
// globalSchemaLoader caches the SchemaLoader built by initSchemaLoader.
var globalSchemaLoader *SchemaLoader

// initSchemaLoader returns the cached global SchemaLoader, building it with
// AutoLoadSchemas on first use.
//
// NOTE(review): the nil check is unsynchronized, so this is only safe if the
// first call happens before concurrent use (e.g. during router setup) — confirm.
func initSchemaLoader() *SchemaLoader {
	if globalSchemaLoader != nil {
		return globalSchemaLoader
	}
	globalSchemaLoader = AutoLoadSchemas()
	return globalSchemaLoader
}
// ResponseValidationMiddleware creates middleware that automatically validates responses.
//
// The handler's output is buffered in a responseCaptureWriter. After the handler chain
// finishes, a 200 response is processed as follows:
//   - It is parsed as JSON and validated against the schema the loader associates with
//     the request path and method.
//   - If validation passes, or no schema is known, or the body is not JSON, the buffered
//     response is replayed to the client unchanged.
//   - If validation fails, a 400 JSON error document is sent instead of the original body.
//
// Non-200 responses are replayed without validation.
func ResponseValidationMiddleware(logger *observability.Logger) gin.HandlerFunc {
	loader := initSchemaLoader()

	return func(c *gin.Context) {
		ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "response_validation")
		defer span.End()

		// Swap in a buffering writer so nothing reaches the client before validation.
		realWriter := c.Writer
		capture := &responseCaptureWriter{
			ResponseWriter: realWriter,
			body:           &bytes.Buffer{},
			status:         0,
		}
		c.Writer = capture

		c.Next()

		status := capture.status
		if status == 0 {
			status = c.Writer.Status()
		}

		// replay restores the real writer and sends the buffered status and body unchanged.
		replay := func() {
			c.Writer = realWriter
			c.Writer.WriteHeader(status)
			_, _ = c.Writer.Write(capture.body.Bytes())
		}

		if status != http.StatusOK {
			span.SetAttributes(observability.AttributeTypeFilter("non_200_status"))
			replay()
			return
		}

		var payload interface{}
		if parseErr := json.Unmarshal(capture.body.Bytes(), &payload); parseErr != nil {
			span.SetAttributes(observability.AttributeTypeFilter("json_parse_failed"))
			logger.Error(ctx, "Failed to parse JSON response", parseErr, map[string]interface{}{
				"method": c.Request.Method,
				"path":   c.Request.URL.Path,
			})
			replay()
			return
		}

		schemaName := loader.DetermineSchemaFromPath(c.Request.URL.Path, c.Request.Method)
		span.SetAttributes(
			observability.AttributeSearch(c.Request.URL.Path),
			observability.AttributeTypeFilter(c.Request.Method),
		)

		if schemaName == "" {
			span.SetAttributes(observability.AttributeTypeFilter("no_schema_found"))
			logger.Warn(ctx, "No schema found for endpoint", map[string]interface{}{
				"method": c.Request.Method,
				"path":   c.Request.URL.Path,
			})
			replay()
			return
		}

		span.SetAttributes(observability.AttributeSearch(schemaName))
		validationErr := loader.ValidateData(payload, schemaName)
		if validationErr == nil {
			span.SetAttributes(observability.AttributeTypeFilter("validation_passed"))
			replay()
			return
		}

		span.SetAttributes(observability.AttributeTypeFilter("validation_failed"))
		// Log at most the first 200 bytes of the offending body; slicing the bytes first
		// avoids copying the whole body into a string just to truncate it.
		raw := capture.body.Bytes()
		preview := string(raw[:int(math.Min(200, float64(len(raw))))])
		logger.Error(ctx, "Response validation failed", validationErr, map[string]interface{}{
			"method":        c.Request.Method,
			"path":          c.Request.URL.Path,
			"schema_name":   schemaName,
			"error":         validationErr.Error(),
			"response_data": preview,
		})

		// Replace the original body with an error document.
		c.Writer = realWriter
		c.Writer.WriteHeader(http.StatusBadRequest)
		_ = json.NewEncoder(c.Writer).Encode(gin.H{
			"error":   "Response validation failed",
			"message": "API response does not match the specification",
			"method":  c.Request.Method,
			"path":    c.Request.URL.Path,
			"schema":  schemaName,
			"details": validationErr.Error(),
		})
	}
}
// responseCaptureWriter buffers a handler's response so ResponseValidationMiddleware can
// validate it before anything reaches the client.
//
// Every method that would commit a status line or body bytes is intercepted:
//   - Write and WriteString append to body.
//   - WriteHeader records the status.
//   - WriteHeaderNow does nothing.
//
// Without the WriteHeaderNow override, gin helpers such as Context.AbortWithStatus
// (and Render for bodyless statuses) call the embedded writer's WriteHeaderNow. That
// commits the embedded writer's own status, still gin's default of 200 because
// WriteHeader was intercepted, so the real status is lost. Without the WriteString
// override, the embedded WriteString would send bytes straight to the client.
// Other methods (Flush, Hijack, Header, ...) are still delegated to the embedded writer.
type responseCaptureWriter struct {
	gin.ResponseWriter
	// body accumulates everything the handler writes.
	body *bytes.Buffer
	// status is the last status passed to WriteHeader; 0 means none was set.
	status int
}

// WriteHeader records statusCode without sending it; the middleware sends the final
// status to the real writer after validation.
func (w *responseCaptureWriter) WriteHeader(statusCode int) {
	w.status = statusCode
}

// WriteHeaderNow is intentionally a no-op: headers are committed by the middleware when
// it replays the buffered response.
func (w *responseCaptureWriter) WriteHeaderNow() {}

// Write appends b to the buffered body.
func (w *responseCaptureWriter) Write(b []byte) (int, error) {
	return w.body.Write(b)
}

// WriteString appends s to the buffered body instead of letting the embedded writer
// send it directly (io.WriteString prefers this method when it exists).
func (w *responseCaptureWriter) WriteString(s string) (int, error) {
	return w.body.WriteString(s)
}

// Status reports the captured status when one has been set, otherwise the embedded
// writer's status.
func (w *responseCaptureWriter) Status() int {
	if w.status != 0 {
		return w.status
	}
	return w.ResponseWriter.Status()
}
// isStaticFile reports whether path is exempt from request validation.
//
// The exempt paths are the exact paths "/swagger.yaml", "/swaggerz", "/configz" and "/",
// plus anything under the "/backend/" prefix (static assets).
func isStaticFile(path string) bool {
	switch path {
	case "/swagger.yaml", "/swaggerz", "/configz", "/":
		return true
	}
	return strings.HasPrefix(path, "/backend/")
}
// RequestValidationMiddleware creates middleware that prevents undocumented API calls.
//
// For each request it does the following:
//   - Static paths (see isStaticFile) pass through unchanged.
//   - Method/path pairs not documented in the loaded API specification are rejected
//     with 404.
//   - For POST, PUT and PATCH requests whose endpoint has a request-body schema, a
//     non-empty JSON body is validated against that schema, and mismatches are
//     rejected with 400.
//
// The request body is read exactly once and always put back on c.Request.Body, so
// downstream handlers can read it. Bodies that are not valid JSON, or that cannot be
// read, are not rejected here; they are logged and passed through to the handler.
//
// Raw request bodies are deliberately not logged, printed, or attached to spans,
// because they can carry secrets such as passwords or API keys.
func RequestValidationMiddleware(logger *observability.Logger) gin.HandlerFunc {
	// Initialize schema loader once
	schemaLoader := initSchemaLoader()
	return func(c *gin.Context) {
		ctx, span := observability.TraceHandlerFunction(c.Request.Context(), "request_validation")
		defer span.End()

		path := c.Request.URL.Path
		method := c.Request.Method

		// Log all requests for debugging
		logger.Info(ctx, "Request validation middleware called", map[string]interface{}{
			"method": method,
			"path":   path,
		})
		span.SetAttributes(
			observability.AttributeSearch(path),
			observability.AttributeTypeFilter(method),
		)

		// Allow static files to pass through
		if isStaticFile(path) {
			c.Next()
			return
		}

		// Reject endpoints that are not documented in the API specification.
		if !schemaLoader.IsEndpointDocumented(path, method) {
			logger.Warn(ctx, "Undocumented API call attempted", map[string]interface{}{
				"method":     method,
				"path":       path,
				"ip":         c.ClientIP(),
				"user_agent": c.Request.UserAgent(),
			})
			c.JSON(http.StatusNotFound, gin.H{
				"error":   "Endpoint not found",
				"message": "The requested endpoint is not documented in the API specification",
			})
			c.Abort()
			return
		}
		span.SetAttributes(
			observability.AttributeTypeFilter("endpoint_documented"),
		)

		// Only methods that carry a request body are schema-validated.
		if method != http.MethodPost && method != http.MethodPut && method != http.MethodPatch {
			c.Next()
			return
		}

		schemaName := schemaLoader.DetermineRequestSchemaFromPath(path, method)
		logger.Info(ctx, "Request validation schema determined", map[string]interface{}{
			"method":      method,
			"path":        path,
			"schema_name": schemaName,
		})
		if schemaName == "" {
			logger.Warn(ctx, "No schema found for endpoint", map[string]interface{}{
				"method": method,
				"path":   path,
			})
		}

		// GetRawData drains c.Request.Body, so read it exactly once and put the bytes
		// back for the handlers further down the chain.
		body, readErr := c.GetRawData()
		if readErr != nil {
			// The body could not be read completely; skip validation and let the
			// handler deal with the request rather than rejecting it here.
			logger.Warn(ctx, "Failed to read request body for validation", map[string]interface{}{
				"method": method,
				"path":   path,
				"error":  readErr.Error(),
			})
			c.Next()
			return
		}
		c.Request.Body = io.NopCloser(bytes.NewReader(body))

		if schemaName == "" || len(body) == 0 {
			c.Next()
			return
		}

		var requestData interface{}
		if parseErr := json.Unmarshal(body, &requestData); parseErr != nil {
			// Malformed JSON is not rejected here; it is left for the handler to reject.
			logger.Warn(ctx, "Request body is not valid JSON; skipping schema validation", map[string]interface{}{
				"method":      method,
				"path":        path,
				"schema_name": schemaName,
			})
			c.Next()
			return
		}

		if validationErr := schemaLoader.ValidateData(requestData, schemaName); validationErr != nil {
			logger.Error(ctx, "Request validation failed", validationErr, map[string]interface{}{
				"method":      method,
				"path":        path,
				"schema_name": schemaName,
				"error":       validationErr.Error(),
				"body_bytes":  len(body),
			})
			// Every AttributeTypeFilter call writes the same attribute key, so only one
			// value can survive on the span; the error details go into a span event instead.
			span.RecordError(fmt.Errorf("request body does not match schema %q: %w", schemaName, validationErr))
			span.SetAttributes(observability.AttributeTypeFilter("validation_failed"))
			c.JSON(http.StatusBadRequest, gin.H{
				"error":   "Invalid request data",
				"message": "Request data does not match the API specification",
				"method":  method,
				"path":    path,
				"schema":  schemaName,
				"details": validationErr.Error(),
			})
			c.Abort()
			return
		}

		c.Next()
	}
}
// Package models defines data structures used throughout the quiz application.
package models
import (
"database/sql"
"encoding/json"
"time"
"quizapp/internal/api"
)
// User represents a user in the system.
//
// Nullable columns use sql.Null* types. JSON output is produced by User.MarshalJSON,
// which renders invalid Null* values as null and leaves out PasswordHash and AIAPIKey.
// NOTE(review): the yaml tags still include AIAPIKey (as "ai_api_key"), so YAML
// serialization of a User carries the stored API key. PasswordHash is excluded from
// both JSON and YAML.
type User struct {
ID int `json:"id" yaml:"id"`
Username string `json:"username" yaml:"username"`
Email sql.NullString `json:"email" yaml:"email"`
Timezone sql.NullString `json:"timezone" yaml:"timezone"`
PasswordHash sql.NullString `json:"-" yaml:"-"` // Omit from JSON responses
LastActive sql.NullTime `json:"last_active" yaml:"last_active"`
PreferredLanguage sql.NullString `json:"preferred_language" yaml:"preferred_language"`
CurrentLevel sql.NullString `json:"current_level" yaml:"current_level"`
AIProvider sql.NullString `json:"ai_provider" yaml:"ai_provider"`
AIModel sql.NullString `json:"ai_model" yaml:"ai_model"`
AIEnabled sql.NullBool `json:"ai_enabled" yaml:"ai_enabled"`
AIAPIKey sql.NullString `json:"-" yaml:"ai_api_key"` // Omit from JSON responses
CreatedAt time.Time `json:"created_at" yaml:"created_at"`
UpdatedAt time.Time `json:"updated_at" yaml:"updated_at"`
// Roles assigned to the user; omitted from JSON/YAML when empty.
Roles []Role `json:"roles,omitempty" yaml:"roles,omitempty"`
}
// Role represents a named role in the system. Users are linked to roles through
// UserRole, and a User lists its roles in User.Roles.
type Role struct {
ID int `json:"id" yaml:"id"`
Name string `json:"name" yaml:"name"`
Description string `json:"description" yaml:"description"`
CreatedAt time.Time `json:"created_at" yaml:"created_at"`
UpdatedAt time.Time `json:"updated_at" yaml:"updated_at"`
}
// UserRole represents the mapping between users and roles. Each value links the user
// identified by UserID to the role identified by RoleID.
type UserRole struct {
ID int `json:"id" yaml:"id"`
UserID int `json:"user_id" yaml:"user_id"`
RoleID int `json:"role_id" yaml:"role_id"`
CreatedAt time.Time `json:"created_at" yaml:"created_at"`
}
// MarshalJSON customizes JSON marshaling for User to handle sql.NullString and sql.NullTime properly
func (u User) MarshalJSON() (result0 []byte, err error) { // Create a struct with the desired JSON structure
return json.Marshal(&struct {
ID int `json:"id"`
Username string `json:"username"`
Email *string `json:"email"`
Timezone *string `json:"timezone"`
LastActive *time.Time `json:"last_active"`
PreferredLanguage *string `json:"preferred_language"`
CurrentLevel *string `json:"current_level"`
AIProvider *string `json:"ai_provider"`
AIModel *string `json:"ai_model"`
AIEnabled *bool `json:"ai_enabled"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Roles []Role `json:"roles,omitempty"`
}{
ID: u.ID,
Username: u.Username,
Email: nullStringToPointer(u.Email),
Timezone: nullStringToPointer(u.Timezone),
LastActive: nullTimeToPointer(u.LastActive),
PreferredLanguage: nullStringToPointer(u.PreferredLanguage),
CurrentLevel: nullStringToPointer(u.CurrentLevel),
AIProvider: nullStringToPointer(u.AIProvider),
AIModel: nullStringToPointer(u.AIModel),
AIEnabled: nullBoolToPointer(u.AIEnabled),
CreatedAt: u.CreatedAt,
UpdatedAt: u.UpdatedAt,
Roles: u.Roles,
})
}
// Helper functions for converting sql.Null types to pointers.

// nullStringToPointer returns nil when ns is invalid, otherwise a pointer to a copy of
// its string value.
func nullStringToPointer(ns sql.NullString) *string {
	if !ns.Valid {
		return nil
	}
	value := ns.String
	return &value
}
func nullTimeToPointer(nt sql.NullTime) *time.Time {
if nt.Valid {
return &nt.Time
}
return nil
}
// nullBoolToPointer returns nil when nb is invalid, otherwise a pointer to a copy of
// its bool value (so a valid false yields a non-nil pointer to false).
func nullBoolToPointer(nb sql.NullBool) *bool {
	if nb.Valid {
		flag := nb.Bool
		return &flag
	}
	return nil
}
// nullInt32ToPointer returns nil when ni is invalid, otherwise a pointer to a copy of
// its int32 value (so a valid 0 yields a non-nil pointer to 0).
func nullInt32ToPointer(ni sql.NullInt32) *int32 {
	if !ni.Valid {
		return nil
	}
	n := ni.Int32
	return &n
}
// UserAPIKey represents an API key stored for a user for one specific provider.
// The key itself (APIKey) is excluded from JSON output by its `json:"-"` tag.
type UserAPIKey struct {
ID int `json:"id"`
UserID int `json:"user_id"`
Provider string `json:"provider"`
APIKey string `json:"-"` // Omit from JSON responses for security
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}
// Question represents a quiz question.
//
// Content holds the question payload as a free-form map. GetCorrectAnswerText treats
// Content["options"] as a list of answer strings and CorrectAnswer as the index of the
// correct one. NOTE(review): other Content keys presumably depend on Type and are not
// visible here.
type Question struct {
ID int `json:"id" yaml:"id"`
Type QuestionType `json:"type" yaml:"type"`
Language string `json:"language" yaml:"language"`
Level string `json:"level" yaml:"level"`
DifficultyScore float64 `json:"difficulty_score" yaml:"difficulty_score"`
Content map[string]interface{} `json:"content" yaml:"content"`
// Index of the correct entry in Content["options"].
CorrectAnswer int `json:"correct_answer" yaml:"correct_answer"`
Explanation string `json:"explanation,omitempty" yaml:"explanation"`
CreatedAt time.Time `json:"created_at" yaml:"created_at"`
Status QuestionStatus `json:"status" yaml:"status"`
// Test data field for specifying which users should have this question
Users []string `json:"users,omitempty" yaml:"users,omitempty"`
// Variety elements for question generation diversity
TopicCategory string `json:"topic_category,omitempty" yaml:"topic_category"`
GrammarFocus string `json:"grammar_focus,omitempty" yaml:"grammar_focus"`
VocabularyDomain string `json:"vocabulary_domain,omitempty" yaml:"vocabulary_domain"`
Scenario string `json:"scenario,omitempty" yaml:"scenario"`
StyleModifier string `json:"style_modifier,omitempty" yaml:"style_modifier"`
DifficultyModifier string `json:"difficulty_modifier,omitempty" yaml:"difficulty_modifier"`
TimeContext string `json:"time_context,omitempty" yaml:"time_context"`
}
// UserQuestion represents the mapping between users and questions. Each value links the
// user identified by UserID to the question identified by QuestionID.
type UserQuestion struct {
ID int `json:"id"`
UserID int `json:"user_id"`
QuestionID int `json:"question_id"`
CreatedAt time.Time `json:"created_at"`
}
// QuestionReport represents a report of a question by a user: which question was
// reported (QuestionID), by whom (ReportedByUserID), and the free-text reason.
type QuestionReport struct {
ID int `json:"id"`
QuestionID int `json:"question_id"`
ReportedByUserID int `json:"reported_by_user_id"`
ReportReason string `json:"report_reason"`
CreatedAt time.Time `json:"created_at"`
}
// QuestionType represents the type of question. Values are stored and serialized as
// their string form (for example "fill_blank").
type QuestionType string
// QuestionStatus represents the status of a question (see the QuestionStatus* constants).
type QuestionStatus string
const (
// QuestionStatusActive is for questions that are in active use
QuestionStatusActive QuestionStatus = "active"
// QuestionStatusReported is for questions that have been reported as incorrect
QuestionStatusReported QuestionStatus = "reported"
)
// Question types supported by the system
const (
// Vocabulary represents vocabulary in context questions
Vocabulary QuestionType = "vocabulary"
// FillInBlank represents fill-in-the-blank questions
FillInBlank QuestionType = "fill_blank"
// QuestionAnswer represents simple Q&A questions
QuestionAnswer QuestionType = "qa"
// ReadingComprehension represents reading comprehension questions
ReadingComprehension QuestionType = "reading_comprehension"
)
// UserResponse represents a user's answer to a question.
//
// ConfidenceLevel is nullable. UserResponse.MarshalJSON emits it as JSON null when it
// is not set.
type UserResponse struct {
ID int `json:"id" yaml:"id"`
UserID int `json:"user_id" yaml:"user_id"`
QuestionID int `json:"question_id" yaml:"question_id"`
// Index of the option the user chose.
UserAnswerIndex int `json:"user_answer_index" yaml:"user_answer_index"`
IsCorrect bool `json:"is_correct" yaml:"is_correct"`
ResponseTimeMs int `json:"response_time_ms" yaml:"response_time_ms"`
ConfidenceLevel sql.NullInt32 `json:"confidence_level" yaml:"confidence_level"`
CreatedAt time.Time `json:"created_at" yaml:"created_at"`
}
// MarshalJSON customizes JSON marshaling for UserResponse to handle sql.NullInt32 properly
func (ur UserResponse) MarshalJSON() (result0 []byte, err error) {
return json.Marshal(&struct {
ID int `json:"id"`
UserID int `json:"user_id"`
QuestionID int `json:"question_id"`
UserAnswerIndex int `json:"user_answer_index"`
IsCorrect bool `json:"is_correct"`
ResponseTimeMs int `json:"response_time_ms"`
ConfidenceLevel *int32 `json:"confidence_level"`
CreatedAt time.Time `json:"created_at"`
}{
ID: ur.ID,
UserID: ur.UserID,
QuestionID: ur.QuestionID,
UserAnswerIndex: ur.UserAnswerIndex,
IsCorrect: ur.IsCorrect,
ResponseTimeMs: ur.ResponseTimeMs,
ConfidenceLevel: nullInt32ToPointer(ur.ConfidenceLevel),
CreatedAt: ur.CreatedAt,
})
}
// PerformanceMetrics tracks one user's performance for a topic/language/level
// combination: attempt counters, average response time, and a difficulty adjustment.
// AccuracyRate derives a percentage from the counters.
type PerformanceMetrics struct {
ID int `json:"id"`
UserID int `json:"user_id"`
Topic string `json:"topic"`
Language string `json:"language"`
Level string `json:"level"`
TotalAttempts int `json:"total_attempts"`
CorrectAttempts int `json:"correct_attempts"`
AverageResponseTimeMs float64 `json:"average_response_time_ms"`
DifficultyAdjustment float64 `json:"difficulty_adjustment"`
LastUpdated time.Time `json:"last_updated"`
}
// AccuracyRate returns CorrectAttempts as a percentage of TotalAttempts, in the range
// 0-100 for consistent counters. It returns 0 when there have been no attempts.
func (pm *PerformanceMetrics) AccuracyRate() float64 {
	total := pm.TotalAttempts
	if total != 0 {
		// Divide first, then scale, to keep the exact floating-point result.
		ratio := float64(pm.CorrectAttempts) / float64(total)
		return ratio * 100
	}
	return 0.0
}
// QuestionRequest represents a request for a new question for a user at a given
// language and level. QuestionType is optional and is omitted from JSON when empty.
type QuestionRequest struct {
UserID int `json:"user_id"`
Language string `json:"language"`
Level string `json:"level"`
QuestionType QuestionType `json:"question_type,omitempty"`
}
// AnswerRequest represents a user's answer submission for a question.
// NOTE(review): UserAnswer is a string here, while UserResponse records an
// option index (UserAnswerIndex); confirm how callers map between the two.
type AnswerRequest struct {
QuestionID int `json:"question_id"`
UserAnswer string `json:"user_answer"`
ResponseTimeMs int `json:"response_time_ms"`
}
// AnswerResponse represents the response to an answer submission. NextDifficulty is
// optional and is omitted from JSON when empty.
type AnswerResponse struct {
IsCorrect bool `json:"is_correct"`
CorrectAnswer string `json:"correct_answer"`
UserAnswer string `json:"user_answer"`
Explanation string `json:"explanation"`
NextDifficulty string `json:"next_difficulty,omitempty"`
}
// GetCorrectAnswerText returns the text of the correct answer from the question content.
//
// It looks up Content["options"], expects a []interface{}, and returns the string at
// index CorrectAnswer. It returns "" in each of these cases:
//   - the options entry is missing or not a []interface{};
//   - the index is out of range;
//   - the selected option is not a string.
func (q *Question) GetCorrectAnswerText() string {
	options, ok := q.Content["options"].([]interface{})
	if !ok {
		return ""
	}
	idx := q.CorrectAnswer
	if idx < 0 || idx >= len(options) {
		return ""
	}
	text, _ := options[idx].(string)
	return text
}
// UserSettings represents user preference settings.
//
// NOTE(review): User, UserAPIKey and UserLearningPreferences all hide their API key
// from JSON. Here, AIAPIKey is (de)serialized under the JSON key "api_key". That is
// presumably intended for incoming settings payloads, but a UserSettings value must not
// be written into a response as-is. Confirm against the handlers that use it.
type UserSettings struct {
Language string `json:"language" yaml:"language"`
Level string `json:"level" yaml:"level"`
AIProvider string `json:"ai_provider" yaml:"ai_provider"`
AIModel string `json:"ai_model" yaml:"ai_model"`
AIEnabled bool `json:"ai_enabled" yaml:"ai_enabled"`
AIAPIKey string `json:"api_key" yaml:"ai_api_key"`
}
// UserLearningPreferences represents user learning preferences and settings.
//
// The db tags name the columns each field maps to. AIAPIKey maps to the ai_api_key
// column but is excluded from JSON. LastDailyReminderSent is a pointer, so it is
// encoded as JSON null when unset.
type UserLearningPreferences struct {
ID int `json:"id" db:"id"`
UserID int `json:"user_id" db:"user_id"`
PreferredLanguage string `json:"preferred_language" db:"preferred_language"`
CurrentLevel string `json:"current_level" db:"current_level"`
AIProvider string `json:"ai_provider" db:"ai_provider"`
AIModel string `json:"ai_model" db:"ai_model"`
AIEnabled bool `json:"ai_enabled" db:"ai_enabled"`
AIAPIKey string `json:"-" db:"ai_api_key"` // Omit from JSON for security
DailyGoal int `json:"daily_goal" db:"daily_goal"`
WeeklyGoal int `json:"weekly_goal" db:"weekly_goal"`
PreferredQuestionType string `json:"preferred_question_type" db:"preferred_question_type"`
PreferredQuestionTypes []string `json:"preferred_question_types" db:"preferred_question_types"`
PreferredDifficultyLevel string `json:"preferred_difficulty_level" db:"preferred_difficulty_level"`
PreferredTopics []string `json:"preferred_topics" db:"preferred_topics"`
PreferredQuestionCount int `json:"preferred_question_count" db:"preferred_question_count"`
SpacedRepetitionEnabled bool `json:"spaced_repetition_enabled" db:"spaced_repetition_enabled"`
AdaptiveDifficultyEnabled bool `json:"adaptive_difficulty_enabled" db:"adaptive_difficulty_enabled"`
FocusOnWeakAreas bool `json:"focus_on_weak_areas" db:"focus_on_weak_areas"`
IncludeReviewQuestions bool `json:"include_review_questions" db:"include_review_questions"`
FreshQuestionRatio float64 `json:"fresh_question_ratio" db:"fresh_question_ratio"`
KnownQuestionPenalty float64 `json:"known_question_penalty" db:"known_question_penalty"`
ReviewIntervalDays int `json:"review_interval_days" db:"review_interval_days"`
WeakAreaBoost float64 `json:"weak_area_boost" db:"weak_area_boost"`
StudyTime string `json:"study_time" db:"study_time"`
DailyReminderEnabled bool `json:"daily_reminder_enabled" db:"daily_reminder_enabled"`
// Preferred TTS voice (e.g., it-IT-IsabellaNeural)
TTSVoice string `json:"tts_voice" db:"tts_voice"`
LastDailyReminderSent *time.Time `json:"last_daily_reminder_sent" db:"last_daily_reminder_sent"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
}
// UserProgress represents a user's overall progress: aggregate counts and accuracy,
// per-topic PerformanceMetrics keyed by topic name, weak areas, recent answers, and an
// optional suggested level (omitted from JSON when empty).
// NOTE(review): "keyed by topic name" is inferred from the field name; confirm.
type UserProgress struct {
CurrentLevel string `json:"current_level"`
TotalQuestions int `json:"total_questions"`
CorrectAnswers int `json:"correct_answers"`
AccuracyRate float64 `json:"accuracy_rate"`
PerformanceByTopic map[string]*PerformanceMetrics `json:"performance_by_topic"`
WeakAreas []string `json:"weak_areas"`
RecentActivity []UserResponse `json:"recent_activity"`
SuggestedLevel string `json:"suggested_level,omitempty"`
}
// AIQuestionGenRequest represents a request to the AI service for question generation:
// how many questions (Count) of which type, for a language and level.
// RecentQuestionHistory is never serialized to JSON.
type AIQuestionGenRequest struct {
Language string `json:"language"`
Level string `json:"level"`
QuestionType QuestionType `json:"question_type"`
Count int `json:"count"`
RecentQuestionHistory []string `json:"-"` // Don't include in JSON, internal use
}
// AIChatRequest represents a request to the AI service for a new chat feature.
//
// Only ConversationHistory and RecentQuestionHistory carry json tags. If this struct is
// JSON-encoded, the remaining fields use their Go field names as keys.
// NOTE(review): confirm that mix is intentional.
type AIChatRequest struct {
Language string
Level string
QuestionType QuestionType // Question type for context
Question string
Options []string
Passage string // For reading comprehension
UserAnswer string // Optional
CorrectAnswer string // Optional
IsCorrect *bool // Optional
UserMessage string
ConversationHistory []ChatMessage `json:"conversation_history,omitempty"`
RecentQuestionHistory []string `json:"-"` // Don't include in JSON, internal use
}
// ChatMessage represents a single message in the chat conversation: who sent it (Role)
// and its text (Content).
type ChatMessage struct {
Role api.ChatMessageRole `json:"role"` // "user" or "assistant"
Content string `json:"content"` // The message content
}
// AIExplanationRequest represents a request for an explanation of a wrong answer. It
// carries the question text, the user's answer and the correct answer, plus language
// and level context.
type AIExplanationRequest struct {
Question string `json:"question"`
UserAnswer string `json:"user_answer"`
CorrectAnswer string `json:"correct_answer"`
Language string `json:"language"`
Level string `json:"level"`
}
// MarshalContentToJSON serializes the question content to JSON string
func (q *Question) MarshalContentToJSON() (result0 string, err error) {
data, err := json.Marshal(q.Content)
return string(data), err
}
// UnmarshalContentFromJSON decodes the JSON string data into q.Content and returns any
// json.Unmarshal error.
func (q *Question) UnmarshalContentFromJSON(data string) error {
	raw := []byte(data)
	return json.Unmarshal(raw, &q.Content)
}
// WorkerSettings represents one worker configuration setting stored in the database,
// as a key/value pair of strings (db tags give the column names).
type WorkerSettings struct {
ID int `json:"id" db:"id"`
SettingKey string `json:"setting_key" db:"setting_key"`
SettingValue string `json:"setting_value" db:"setting_value"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
}
// WorkerStatus represents worker health and activity status for one worker instance.
//
// Nullable columns use sql.Null* types. WorkerStatus.MarshalJSON emits them as JSON null
// when invalid.
// NOTE(review): both LastRunEnd and LastRunFinish exist; confirm whether they are
// distinct or one is redundant.
type WorkerStatus struct {
ID int `json:"id" db:"id"`
WorkerInstance string `json:"worker_instance" db:"worker_instance"`
IsRunning bool `json:"is_running" db:"is_running"`
IsPaused bool `json:"is_paused" db:"is_paused"`
CurrentActivity sql.NullString `json:"current_activity" db:"current_activity"`
LastHeartbeat sql.NullTime `json:"last_heartbeat" db:"last_heartbeat"`
LastRunStart sql.NullTime `json:"last_run_start" db:"last_run_start"`
LastRunEnd sql.NullTime `json:"last_run_end" db:"last_run_end"`
LastRunFinish sql.NullTime `json:"last_run_finish" db:"last_run_finish"`
LastRunError sql.NullString `json:"last_run_error" db:"last_run_error"`
TotalQuestionsProcessed int `json:"total_questions_processed" db:"total_questions_processed"`
TotalQuestionsGenerated int `json:"total_questions_generated" db:"total_questions_generated"`
TotalRuns int `json:"total_runs" db:"total_runs"`
CreatedAt time.Time `json:"created_at" db:"created_at"`
UpdatedAt time.Time `json:"updated_at" db:"updated_at"`
}
// MarshalJSON customizes JSON marshaling for WorkerStatus to handle sql.NullString and sql.NullTime properly
func (ws WorkerStatus) MarshalJSON() (result0 []byte, err error) {
return json.Marshal(&struct {
ID int `json:"id"`
WorkerInstance string `json:"worker_instance"`
IsRunning bool `json:"is_running"`
IsPaused bool `json:"is_paused"`
CurrentActivity *string `json:"current_activity"`
LastHeartbeat *time.Time `json:"last_heartbeat"`
LastRunStart *time.Time `json:"last_run_start"`
LastRunEnd *time.Time `json:"last_run_end"`
LastRunFinish *time.Time `json:"last_run_finish"`
LastRunError *string `json:"last_run_error"`
TotalQuestionsProcessed int `json:"total_questions_processed"`
TotalQuestionsGenerated int `json:"total_questions_generated"`
TotalRuns int `json:"total_runs"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
}{
ID: ws.ID,
WorkerInstance: ws.WorkerInstance,
IsRunning: ws.IsRunning,
IsPaused: ws.IsPaused,
CurrentActivity: nullStringToPointer(ws.CurrentActivity),
LastHeartbeat: nullTimeToPointer(ws.LastHeartbeat),
LastRunStart: nullTimeToPointer(ws.LastRunStart),
LastRunEnd: nullTimeToPointer(ws.LastRunEnd),
LastRunFinish: nullTimeToPointer(ws.LastRunFinish),
LastRunError: nullStringToPointer(ws.LastRunError),
TotalQuestionsProcessed: ws.TotalQuestionsProcessed,
TotalQuestionsGenerated: ws.TotalQuestionsGenerated,
TotalRuns: ws.TotalRuns,
CreatedAt: ws.CreatedAt,
UpdatedAt: ws.UpdatedAt,
})
}
package observability
import (
"context"
"fmt"
"quizapp/internal/models"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
// globalTracer is the application-wide tracer set by InitGlobalTracer.
// NOTE(review): InitGlobalTracer is presumably called once at startup, before any
// concurrent use; it performs an unsynchronized write.
var globalTracer trace.Tracer

// InitGlobalTracer initializes the global tracer for the application.
func InitGlobalTracer() {
	globalTracer = otel.Tracer("quiz-app")
}

// GetGlobalTracer returns the global tracer instance for the application.
//
// If InitGlobalTracer has not been called, it returns a tracer obtained from the global
// OpenTelemetry provider under the same "quiz-app" name, without caching it. Caching it
// here would write the package variable from whichever goroutine calls first. This
// function runs for every traced call, including inside concurrently served request
// handlers, so that write would be a data race.
func GetGlobalTracer() trace.Tracer {
	if t := globalTracer; t != nil {
		return t
	}
	return otel.Tracer("quiz-app")
}
// TraceFunction starts a span named "<serviceName>.<functionName>" on the global tracer,
// carrying the given attributes. It returns the derived context and the span; the
// caller must End the span.
func TraceFunction(ctx context.Context, serviceName, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
	spanName := serviceName + "." + functionName
	opts := trace.WithAttributes(attributes...)
	return GetGlobalTracer().Start(ctx, spanName, opts)
}
// TraceFunctionWithErrorHandling runs fn inside a span named "<serviceName>.<functionName>"
// and returns fn's error.
//
// The span is tagged as follows:
//   - If fn returns an error: error=true and error.message.
//   - If fn panics: error=true, error.type="panic" and error.message, after which the
//     panic is re-raised.
//
// The span is ended exactly once in every case. The span's context is discarded, and
// fn receives no context, so spans started inside fn are not children of this span.
func TraceFunctionWithErrorHandling(ctx context.Context, serviceName, functionName string, fn func() error, attributes ...attribute.KeyValue) error {
	_, span := TraceFunction(ctx, serviceName, functionName, attributes...)
	// Deferred calls run last-in-first-out. On a panic, the handler below annotates the
	// span and re-panics; this End still runs while the panic unwinds.
	defer span.End()
	defer func() {
		r := recover()
		if r == nil {
			return
		}
		span.SetAttributes(
			attribute.Bool("error", true),
			attribute.String("error.type", "panic"),
			attribute.String("error.message", fmt.Sprintf("%v", r)),
		)
		panic(r) // re-panic
	}()

	if err := fn(); err != nil {
		span.SetAttributes(
			attribute.Bool("error", true),
			attribute.String("error.message", err.Error()),
		)
		return err
	}
	return nil
}
// TraceAIFunction starts a new span for an AI service function.
func TraceAIFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
return TraceFunction(ctx, "ai", functionName, attributes...)
}
// TraceUserFunction starts a new span for a user service function.
func TraceUserFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
return TraceFunction(ctx, "user", functionName, attributes...)
}
// TraceQuestionFunction starts a new span for a question service function.
func TraceQuestionFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
return TraceFunction(ctx, "question", functionName, attributes...)
}
// TraceWorkerFunction starts a new span for a worker service function.
func TraceWorkerFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
return TraceFunction(ctx, "worker", functionName, attributes...)
}
// TraceLearningFunction starts a new span for a learning service function.
func TraceLearningFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
return TraceFunction(ctx, "learning", functionName, attributes...)
}
// TraceHandlerFunction starts a new span for a handler function.
func TraceHandlerFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
return TraceFunction(ctx, "handler", functionName, attributes...)
}
// TraceVarietyFunction starts a new span for a variety service function.
func TraceVarietyFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
return TraceFunction(ctx, "variety", functionName, attributes...)
}
// TraceOAuthFunction starts a new span for an OAuth service function.
func TraceOAuthFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
return TraceFunction(ctx, "oauth", functionName, attributes...)
}
// TraceCleanupFunction starts a new span for a cleanup service function.
func TraceCleanupFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
return TraceFunction(ctx, "cleanup", functionName, attributes...)
}
// TraceDatabaseFunction starts a new span for a database function.
func TraceDatabaseFunction(ctx context.Context, functionName string, attributes ...attribute.KeyValue) (context.Context, trace.Span) {
return TraceFunction(ctx, "database", functionName, attributes...)
}
// AttributeQuestion returns a tracing attribute for a question's ID.
//
// It emits the same int-typed "question.id" attribute as AttributeQuestionID.
// Previously it emitted a string, so the same key carried two different value types
// depending on which helper tagged the span. q must be non-nil.
func AttributeQuestion(q *models.Question) attribute.KeyValue {
	return attribute.Int("question.id", q.ID)
}
// AttributeQuestionID returns a tracing attribute for a question ID.
func AttributeQuestionID(id int) attribute.KeyValue {
return attribute.Int("question.id", id)
}
// AttributeUserID returns a tracing attribute for a user ID.
func AttributeUserID(id int) attribute.KeyValue {
return attribute.Int("user.id", id)
}
// AttributeLanguage returns a tracing attribute for a language.
func AttributeLanguage(lang string) attribute.KeyValue {
return attribute.String("language", lang)
}
// AttributeLevel returns a tracing attribute for a level.
func AttributeLevel(level string) attribute.KeyValue {
return attribute.String("level", level)
}
// AttributeQuestionType returns a tracing attribute for a question type.
func AttributeQuestionType(qType interface{}) attribute.KeyValue {
return attribute.String("question.type", fmt.Sprintf("%v", qType))
}
// AttributeLimit returns a tracing attribute for a limit value.
func AttributeLimit(limit int) attribute.KeyValue {
return attribute.Int("limit", limit)
}
// AttributePage returns a tracing attribute for a page value.
func AttributePage(page int) attribute.KeyValue {
return attribute.Int("page", page)
}
// AttributePageSize returns a tracing attribute for a page size value.
func AttributePageSize(size int) attribute.KeyValue {
return attribute.Int("page_size", size)
}
// AttributeSearch builds the string "search" span attribute.
func AttributeSearch(search string) attribute.KeyValue {
	return attribute.Key("search").String(search)
}
// AttributeTypeFilter builds the string "type_filter" span attribute.
func AttributeTypeFilter(typeFilter string) attribute.KeyValue {
	return attribute.Key("type_filter").String(typeFilter)
}
// AttributeStatusFilter builds the string "status_filter" span attribute.
func AttributeStatusFilter(statusFilter string) attribute.KeyValue {
	return attribute.Key("status_filter").String(statusFilter)
}
// Package observability provides OpenTelemetry tracing, metrics, and structured logging
// with trace correlation for the quiz application.
package observability
import (
"context"
"os"
"quizapp/internal/config"
"go.opentelemetry.io/contrib/bridges/otelzap"
"go.opentelemetry.io/otel/exporters/otlp/otlplog/otlploggrpc"
"go.opentelemetry.io/otel/sdk/log"
"go.uber.org/zap"
"go.uber.org/zap/zapcore"
)
// Logger wraps the zap logger with OpenTelemetry context support
//
// The *zap.Logger is embedded, so zap methods that Logger does not redefine stay
// callable through it. Debug, Info, Warn and Error are redefined below as
// context-taking variants that accept map-based fields.
//
// NOTE(review): logWithContext currently ignores its context argument, so no
// trace/span IDs are attached to log records by this wrapper — confirm whether
// correlation is expected here.
type Logger struct {
	*zap.Logger
}
// NewLogger builds a Logger at Info level; see NewLoggerWithLevel for how cfg
// (including a nil or logging-disabled cfg) is interpreted.
func NewLogger(cfg *config.OpenTelemetryConfig) *Logger {
	const defaultLevel = zapcore.InfoLevel
	return NewLoggerWithLevel(cfg, defaultLevel)
}
// NewLoggerWithLevel creates a Logger with the given minimum level and, when
// cfg.Endpoint is set, additionally exports every record over OTLP/gRPC.
//
// A nil cfg or cfg.EnableLogging == false yields a no-op logger. Local output
// uses zap's production config (timestamp/ISO8601/stacktrace keys adjusted), or
// zap's development config when ENV=development. If that config cannot be built,
// zap.NewExample is used as a fallback and the build error is logged. A failure
// to create the OTLP exporter is logged, and the logger continues with local
// output only.
//
// NOTE(review): the OTLP exporter always uses a plaintext connection, and the
// LoggerProvider created here is never shut down. Records still buffered in its
// batch processor at process exit may therefore be lost — confirm whether a
// shutdown/flush hook is needed.
func NewLoggerWithLevel(cfg *config.OpenTelemetryConfig, level zapcore.Level) *Logger {
	if cfg == nil || !cfg.EnableLogging {
		return &Logger{Logger: zap.NewNop()}
	}

	var zapConfig zap.Config
	if os.Getenv("ENV") == "development" {
		zapConfig = zap.NewDevelopmentConfig()
	} else {
		zapConfig = zap.NewProductionConfig()
		zapConfig.EncoderConfig.TimeKey = "timestamp"
		zapConfig.EncoderConfig.EncodeTime = zapcore.ISO8601TimeEncoder
		zapConfig.EncoderConfig.StacktraceKey = "stacktrace"
	}
	zapConfig.Level = zap.NewAtomicLevelAt(level)

	zapLogger, err := zapConfig.Build()
	if err != nil {
		// Keep logging even if the configuration is rejected, but say why.
		zapLogger = zap.NewExample()
		zapLogger.Warn("Failed to build configured logger, using fallback logger", zap.Error(err))
	}

	// cfg.EnableLogging is already known to be true here; only the endpoint matters.
	if cfg.Endpoint == "" {
		zapLogger.Info("OTLP logging not enabled", zap.Bool("enable_logging", cfg.EnableLogging), zap.String("endpoint", cfg.Endpoint))
		return &Logger{Logger: zapLogger}
	}

	endpoint := cfg.Endpoint
	zapLogger.Info("Setting up OTLP logging", zap.String("endpoint", endpoint), zap.String("protocol", cfg.Protocol))
	exporter, err := otlploggrpc.New(context.Background(),
		otlploggrpc.WithEndpoint(endpoint),
		otlploggrpc.WithInsecure(),
		// Send the same headers the trace and metric exporters send.
		otlploggrpc.WithHeaders(cfg.Headers),
	)
	if err != nil {
		// Degrade to local output only.
		zapLogger.Error("Failed to create OTLP exporter", zap.Error(err), zap.String("endpoint", endpoint))
		return &Logger{Logger: zapLogger}
	}
	zapLogger.Info("Successfully created OTLP exporter", zap.String("endpoint", endpoint))

	provider := log.NewLoggerProvider(
		log.WithProcessor(log.NewBatchProcessor(exporter)),
	)
	otelCore := otelzap.NewCore("quizapp", otelzap.WithLoggerProvider(provider))

	// Tee the OTLP core onto the existing core via WrapCore so that the options
	// zapConfig.Build applied (caller annotation, stack traces, error output)
	// are kept. Building a fresh logger with zap.New(core) would discard them.
	zapLogger = zapLogger.WithOptions(zap.WrapCore(func(core zapcore.Core) zapcore.Core {
		return zapcore.NewTee(core, otelCore)
	}))
	zapLogger.Info("OTLP logging successfully configured", zap.String("endpoint", endpoint))

	return &Logger{Logger: zapLogger}
}
// Debug emits msg at debug level, flattening any field maps into structured fields.
func (l *Logger) Debug(ctx context.Context, msg string, fields ...map[string]interface{}) {
	const lvl = zapcore.DebugLevel
	l.logWithContext(ctx, lvl, msg, fields...)
}
// Info emits msg at info level, flattening any field maps into structured fields.
func (l *Logger) Info(ctx context.Context, msg string, fields ...map[string]interface{}) {
	const lvl = zapcore.InfoLevel
	l.logWithContext(ctx, lvl, msg, fields...)
}
// Warn emits msg at warn level, flattening any field maps into structured fields.
func (l *Logger) Warn(ctx context.Context, msg string, fields ...map[string]interface{}) {
	const lvl = zapcore.WarnLevel
	l.logWithContext(ctx, lvl, msg, fields...)
}
// Error logs msg at error level. When err is non-nil, its text is added under
// the "error" key, overriding any "error" entry in fields.
//
// The merged fields are copied into a fresh map before "error" is added.
// mergeFields returns the caller's own map when exactly one map is passed, and
// writing into it would mutate (and could race on) caller-owned state.
func (l *Logger) Error(ctx context.Context, msg string, err error, fields ...map[string]interface{}) {
	merged := l.mergeFields(fields...)
	allFields := make(map[string]interface{}, len(merged)+1)
	for k, v := range merged {
		allFields[k] = v
	}
	if err != nil {
		allFields["error"] = err.Error()
	}
	l.logWithContext(ctx, zap.ErrorLevel, msg, allFields)
}
// logWithContext writes msg at level with the merged field maps attached as
// zap fields. Levels other than Debug/Info/Warn/Error are logged at Info.
//
// The context argument is currently unused (no trace correlation is added here).
//
// The record is checked against the core before any fields are built, so
// disabled levels cost no map merging or field allocation. The logger is cloned
// with AddCallerSkip(2) so the "caller" annotation points at the code that
// called Debug/Info/Warn/Error rather than at this helper.
// NOTE(review): the skip of 2 assumes logWithContext is only called directly
// from those exported wrappers — verify if other callers are added.
func (l *Logger) logWithContext(_ context.Context, level zapcore.Level, msg string, fields ...map[string]interface{}) {
	switch level {
	case zap.DebugLevel, zap.InfoLevel, zap.WarnLevel, zap.ErrorLevel:
	default:
		level = zap.InfoLevel
	}

	if !l.Logger.Core().Enabled(level) {
		return
	}
	entry := l.Logger.WithOptions(zap.AddCallerSkip(2)).Check(level, msg)
	if entry == nil {
		// Dropped by the core (e.g. sampling).
		return
	}

	allFields := l.mergeFields(fields...)
	zapFields := make([]zap.Field, 0, len(allFields))
	for k, v := range allFields {
		zapFields = append(zapFields, zap.Any(k, v))
	}
	entry.Write(zapFields...)
}
// mergeFields flattens the optional field maps passed to the logging methods
// into one map. Later maps win on duplicate keys, and nil maps are ignored.
//
// When exactly one map argument is passed and it is non-nil, that same map is
// returned (not a copy), so the result must not be mutated unless the caller
// owns it. No arguments, or a single nil map, yield a new empty map.
func (l *Logger) mergeFields(fields ...map[string]interface{}) map[string]interface{} {
	switch len(fields) {
	case 0:
		return map[string]interface{}{}
	case 1:
		if only := fields[0]; only != nil {
			return only
		}
		return map[string]interface{}{}
	}

	merged := make(map[string]interface{})
	for _, src := range fields {
		// Ranging over a nil map is a no-op, which skips nil entries.
		for key, val := range src {
			merged[key] = val
		}
	}
	return merged
}
// Sync asks the underlying zap logger to flush buffered entries and returns
// whatever error zap reports.
func (l *Logger) Sync() error {
	zl := l.Logger
	return zl.Sync()
}
package observability
import (
"context"
"quizapp/internal/config"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc"
"go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetrichttp"
"go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/resource"
semconv "go.opentelemetry.io/otel/semconv/v1.21.0"
)
// InitMetrics builds an OpenTelemetry MeterProvider that periodically exports
// metrics over OTLP, using the transport selected by cfg.Protocol ("grpc" or
// "http").
//
// cfg.Endpoint is handed to the exporter unchanged (no scheme stripping is
// done). cfg.Insecure selects a plaintext connection, and cfg.Headers are sent
// with every export. The provider is returned but not registered globally; the
// caller owns it.
//
// An error is returned when cfg is nil, when the resource or exporter cannot be
// created, or when the protocol is unsupported.
func InitMetrics(cfg *config.OpenTelemetryConfig) (result0 *metric.MeterProvider, err error) {
	if cfg == nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "otel config is nil")
	}
	ctx := context.Background()

	res, err := resource.New(ctx,
		resource.WithAttributes(
			semconv.ServiceName(cfg.ServiceName),
			semconv.ServiceVersion(cfg.ServiceVersion),
		),
	)
	if err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to create otel resource: %w", err)
	}

	var exporter metric.Exporter
	switch cfg.Protocol {
	case "grpc":
		// Collect options conditionally. A nil Option must never reach the
		// exporter, because it calls a method on every option it receives.
		opts := []otlpmetricgrpc.Option{otlpmetricgrpc.WithEndpoint(cfg.Endpoint)}
		if cfg.Insecure {
			opts = append(opts, otlpmetricgrpc.WithInsecure())
		}
		opts = append(opts, otlpmetricgrpc.WithHeaders(cfg.Headers))
		exp, expErr := otlpmetricgrpc.New(ctx, opts...)
		if expErr != nil {
			return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to create otlp grpc metric exporter: %w", expErr)
		}
		exporter = exp
	case "http":
		opts := []otlpmetrichttp.Option{otlpmetrichttp.WithEndpoint(cfg.Endpoint)}
		if cfg.Insecure {
			opts = append(opts, otlpmetrichttp.WithInsecure())
		}
		opts = append(opts, otlpmetrichttp.WithHeaders(cfg.Headers))
		exp, expErr := otlpmetrichttp.New(ctx, opts...)
		if expErr != nil {
			return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to create otlp http metric exporter: %w", expErr)
		}
		exporter = exp
	default:
		return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "unsupported otel protocol: %s", cfg.Protocol)
	}

	mp := metric.NewMeterProvider(
		metric.WithReader(metric.NewPeriodicReader(exporter)),
		metric.WithResource(res),
	)
	return mp, nil
}
package observability
import (
"errors"
"github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin"
"go.opentelemetry.io/contrib/instrumentation/github.com/gin-gonic/gin/otelgin"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
contextutils "quizapp/internal/utils"
)
// GinMiddleware returns otelgin's tracing middleware for serviceName as-is.
func GinMiddleware(serviceName string) gin.HandlerFunc {
	handler := otelgin.Middleware(serviceName)
	return handler
}
// GinMiddlewareWithErrorHandling returns a Gin middleware that traces requests
// with otelgin and, for responses with status >= 400, annotates the span found
// in the request context. The annotations are: a recorded error (with stack
// trace) and Error status; HTTP status, method and path; handler name and
// severity; the session "user_id" when a session is available; the request size;
// and, when a *contextutils.AppError was attached via c.Error, its code and
// retryability.
//
// NOTE(review): otelgin ends its span — and, in current otelgin versions,
// restores the original request context — before it returns. So the span seen
// below is not the request's server span, and the IsRecording guard makes the
// annotation a cheap no-op in that case. Annotating the server span needs code
// that runs inside otelgin's span lifetime, e.g. a separate middleware
// registered after GinMiddleware.
func GinMiddlewareWithErrorHandling(serviceName string) gin.HandlerFunc {
	// Build the otelgin handler once instead of on every request.
	otelHandler := otelgin.Middleware(serviceName)

	return func(c *gin.Context) {
		// otelgin runs the rest of the handler chain itself (it calls c.Next()),
		// so no extra c.Next() is needed here.
		otelHandler(c)

		statusCode := c.Writer.Status()
		if statusCode < 400 {
			return
		}
		span := trace.SpanFromContext(c.Request.Context())
		if !span.IsRecording() {
			return
		}

		severity := determineErrorSeverity(statusCode, c.Errors)
		errorMsg := "client error"
		if statusCode >= 500 {
			errorMsg = "server error"
		}

		// Prefer the first AppError. Otherwise use the text of the last error
		// seen before it (or of the last error overall).
		var appErr *contextutils.AppError
		for _, ginErr := range c.Errors {
			if ae, ok := ginErr.Err.(*contextutils.AppError); ok {
				appErr = ae
				break
			}
			errorMsg = ginErr.Error()
		}
		if appErr != nil {
			errorMsg = appErr.Message
			severity = string(appErr.Severity)
		}

		span.RecordError(errors.New(errorMsg), trace.WithStackTrace(true))
		span.SetStatus(codes.Error, errorMsg)
		span.SetAttributes(
			attribute.Int("http.status_code", statusCode),
			attribute.String("http.method", c.Request.Method),
			attribute.String("http.path", c.Request.URL.Path),
			attribute.String("error.handler", c.HandlerName()),
			attribute.String("error.severity", severity),
		)

		// sessions.Default panics when no session middleware has run, so only
		// consult it when the session key is present.
		if _, hasSession := c.Get(sessions.DefaultKey); hasSession {
			if userID, ok := sessions.Default(c).Get("user_id").(int); ok {
				span.SetAttributes(attribute.Int("error.user_id", userID))
			}
		}

		if c.Request.ContentLength > 0 {
			span.SetAttributes(attribute.Int64("error.request_size", c.Request.ContentLength))
		}

		if appErr != nil {
			span.SetAttributes(
				attribute.String("error.code", string(appErr.Code)),
				attribute.Bool("error.retryable", contextutils.IsRetryable(appErr)),
			)
		}

		if statusCode >= 500 {
			span.SetAttributes(attribute.Bool("error.server_error", true))
		}
	}
}
// determineErrorSeverity returns the Severity of the first *contextutils.AppError
// attached to the request. Without one, it falls back to the status code:
// >= 500 is Error, >= 400 is Warn, anything else is Info.
func determineErrorSeverity(statusCode int, ginErrs []*gin.Error) string {
	for _, ge := range ginErrs {
		if appErr, isApp := ge.Err.(*contextutils.AppError); isApp {
			return string(appErr.Severity)
		}
	}

	if statusCode >= 500 {
		return string(contextutils.SeverityError)
	}
	if statusCode >= 400 {
		return string(contextutils.SeverityWarn)
	}
	return string(contextutils.SeverityInfo)
}
package observability
import (
"quizapp/internal/config"
"go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/trace"
)
// SetupObservability initializes whichever of tracing, metrics and logging cfg
// enables, and returns the providers and logger.
//
// A non-empty serviceName overwrites cfg.ServiceName (cfg is mutated). A
// tracing failure returns only the error. A metrics failure returns the
// already-created tracer provider alongside the error. The logger is always
// non-nil on success; it is a no-op logger when logging is disabled.
func SetupObservability(cfg *config.OpenTelemetryConfig, serviceName string) (result0 *trace.TracerProvider, result1 *metric.MeterProvider, result2 *Logger, err error) {
	if serviceName != "" {
		cfg.ServiceName = serviceName
	}

	if cfg.EnableTracing {
		if result0, err = InitTracing(cfg); err != nil {
			return nil, nil, nil, err
		}
		InitGlobalTracer()
	}

	if cfg.EnableMetrics {
		if result1, err = InitMetrics(cfg); err != nil {
			return result0, nil, nil, err
		}
	}

	// NewLogger already returns a no-op logger when cfg.EnableLogging is false.
	result2 = NewLogger(cfg)
	return result0, result1, result2, nil
}
package observability
import (
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)
// FinishSpan ends span. If *errPtr is a non-nil error at that moment, the error
// is first recorded on the span (with stack trace) and the span status is set
// to Error. A nil span is ignored.
// Use with a named error return: `defer observability.FinishSpan(span, &err)`
func FinishSpan(span trace.Span, errPtr *error) {
	if span == nil {
		return
	}
	defer span.End()

	if errPtr == nil || *errPtr == nil {
		return
	}
	failure := *errPtr
	span.RecordError(failure, trace.WithStackTrace(true))
	span.SetStatus(codes.Error, failure.Error())
}
package observability
import (
"context"
"quizapp/internal/config"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp"
"go.opentelemetry.io/otel/propagation"
"go.opentelemetry.io/otel/sdk/resource"
"go.opentelemetry.io/otel/sdk/trace"
semconv "go.opentelemetry.io/otel/semconv/v1.21.0"
)
// InitTracing builds an OpenTelemetry TracerProvider that batches spans and
// exports them over OTLP using cfg.Protocol ("grpc" or "http"). It installs the
// provider globally and registers a TraceContext + Baggage propagator globally.
//
// Spans are sampled with ParentBased(TraceIDRatioBased(cfg.SamplingRate)).
// cfg.Endpoint is handed to the exporter unchanged. cfg.Insecure selects a
// plaintext connection for both protocols, and cfg.Headers are sent with every
// export.
//
// An error is returned when cfg is nil, when the resource or exporter cannot be
// created, or when the protocol is unsupported. In those cases no global state
// is modified.
func InitTracing(cfg *config.OpenTelemetryConfig) (result0 *trace.TracerProvider, err error) {
	if cfg == nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "otel config is nil")
	}
	ctx := context.Background()

	res, err := resource.New(ctx,
		resource.WithAttributes(
			semconv.ServiceName(cfg.ServiceName),
			semconv.ServiceVersion(cfg.ServiceVersion),
		),
	)
	if err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to create otel resource: %w", err)
	}

	var exporter trace.SpanExporter
	switch cfg.Protocol {
	case "grpc":
		// Collect options conditionally. A nil Option must never reach the
		// exporter, because it calls a method on every option it receives.
		opts := []otlptracegrpc.Option{otlptracegrpc.WithEndpoint(cfg.Endpoint)}
		if cfg.Insecure {
			opts = append(opts, otlptracegrpc.WithInsecure())
		}
		opts = append(opts, otlptracegrpc.WithHeaders(cfg.Headers))
		exp, expErr := otlptracegrpc.New(ctx, opts...)
		if expErr != nil {
			return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to create otlp grpc exporter: %w", expErr)
		}
		exporter = exp
	case "http":
		// Honor cfg.Insecure here as well, as the HTTP metric exporter does,
		// rather than always forcing a plaintext connection.
		opts := []otlptracehttp.Option{otlptracehttp.WithEndpoint(cfg.Endpoint)}
		if cfg.Insecure {
			opts = append(opts, otlptracehttp.WithInsecure())
		}
		opts = append(opts, otlptracehttp.WithHeaders(cfg.Headers))
		exp, expErr := otlptracehttp.New(ctx, opts...)
		if expErr != nil {
			return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to create otlp http exporter: %w", expErr)
		}
		exporter = exp
	default:
		return nil, contextutils.WrapErrorf(contextutils.ErrInternalError, "unsupported otel protocol: %s", cfg.Protocol)
	}

	sampler := trace.ParentBased(trace.TraceIDRatioBased(cfg.SamplingRate))
	tp := trace.NewTracerProvider(
		trace.WithBatcher(exporter),
		trace.WithResource(res),
		trace.WithSampler(sampler),
	)
	otel.SetTracerProvider(tp)

	// Set up text map propagator for trace context propagation
	// This enables the backend to receive and process trace headers from NGINX
	otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(
		propagation.TraceContext{},
		propagation.Baggage{},
	))
	return tp, nil
}
// Package services provides business logic services for the quiz application.
package services
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"runtime/debug"
"strconv"
"strings"
"sync"
"time"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
"github.com/xeipuuv/gojsonschema"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)
// JSON Schema definitions for grammar field
// These schemas are used with the 'grammar' field in OpenAI-compatible API requests
// to enforce specific JSON structure validation. This ensures that AI models return
// exactly the expected format, eliminating parsing errors and improving reliability.
//
// The grammar field is conditionally included based on provider support (see supportsGrammarField).
// Providers that don't support grammar (like Google) will fall back to prompt-based structure guidance.
//
// Each schema below describes ONE question object. The Batch* variables wrap
// them in a JSON array schema for multi-question responses.
const (
	// Single-item schemas for ai-fix (single question objects)
	// SingleQuestionSchema describes a generic multiple-choice question with
	// exactly four options; "topic" is optional.
	SingleQuestionSchema = `{
"type": "object",
"properties": {
"question": {"type": "string"},
"options": {"type": "array", "items": {"type": "string"}, "minItems": 4, "maxItems": 4},
"correct_answer": {"type": "integer"},
"explanation": {"type": "string"},
"topic": {"type": "string"}
},
"required": ["question", "options", "correct_answer", "explanation"]
}`
	// SingleReadingComprehensionSchema adds a required "passage" to the
	// generic question shape.
	SingleReadingComprehensionSchema = `{
"type": "object",
"properties": {
"passage": {"type": "string"},
"question": {"type": "string"},
"options": {"type": "array", "items": {"type": "string"}, "minItems": 4, "maxItems": 4},
"correct_answer": {"type": "integer"},
"explanation": {"type": "string"},
"topic": {"type": "string"}
},
"required": ["passage", "question", "options", "correct_answer", "explanation"]
}`
	// SingleVocabularyQuestionSchema adds a required "sentence" to the
	// generic question shape.
	SingleVocabularyQuestionSchema = `{
"type": "object",
"properties": {
"sentence": {"type": "string"},
"question": {"type": "string"},
"options": {"type": "array", "items": {"type": "string"}, "minItems": 4, "maxItems": 4},
"correct_answer": {"type": "integer"},
"explanation": {"type": "string"},
"topic": {"type": "string"}
},
"required": ["sentence", "question", "options", "correct_answer", "explanation"]
}`
)
// Batch schemas: JSON array schemas whose "items" is the matching single-item
// schema. They are used for requests that ask for several questions at once.
// fmt.Sprint concatenates its string operands without separators, yielding
// {"type":"array","items":<single schema>}.
var (
	// BatchQuestionsSchema is a batch wrapper around SingleQuestionSchema.
	BatchQuestionsSchema = fmt.Sprint(`{"type":"array","items":`, SingleQuestionSchema, `}`)
	// BatchReadingComprehensionSchema is a batch wrapper around SingleReadingComprehensionSchema.
	BatchReadingComprehensionSchema = fmt.Sprint(`{"type":"array","items":`, SingleReadingComprehensionSchema, `}`)
	// BatchVocabularyQuestionSchema is a batch wrapper around SingleVocabularyQuestionSchema.
	BatchVocabularyQuestionSchema = fmt.Sprint(`{"type":"array","items":`, SingleVocabularyQuestionSchema, `}`)
)
// UserAIConfig holds per-user AI configuration
//
// It is passed to every generation/chat call on AIServiceInterface to select
// the provider, model and credentials used for that call.
type UserAIConfig struct {
	Provider string // provider identifier; also used to decide grammar-field support
	Model    string // model name sent with requests
	APIKey   string // credential for the provider; avoid logging it
	Username string // For logging purposes
}
// AIServiceInterface defines the interface for AI-powered question generation
//
// AIService is the implementation in this file. Every request-style method
// takes the caller's UserAIConfig to pick the provider, model and credentials.
type AIServiceInterface interface {
	// GenerateQuestion produces a single question for req.
	GenerateQuestion(ctx context.Context, userConfig *UserAIConfig, req *models.AIQuestionGenRequest) (*models.Question, error)
	// GenerateQuestions produces a batch of questions for req in one provider call.
	GenerateQuestions(ctx context.Context, userConfig *UserAIConfig, req *models.AIQuestionGenRequest) ([]*models.Question, error)
	// GenerateQuestionsStream presumably delivers questions on progress as they
	// are produced, using the given variety elements — confirm channel ownership
	// (who closes progress) in the implementation.
	GenerateQuestionsStream(ctx context.Context, userConfig *UserAIConfig, req *models.AIQuestionGenRequest, progress chan<- *models.Question, variety *VarietyElements) error
	// GenerateChatResponse returns a complete chat reply for req.
	GenerateChatResponse(ctx context.Context, userConfig *UserAIConfig, req *models.AIChatRequest) (string, error)
	// GenerateChatResponseStream presumably sends partial reply text on chunks.
	GenerateChatResponseStream(ctx context.Context, userConfig *UserAIConfig, req *models.AIChatRequest, chunks chan<- string) error
	// TestConnection checks that the given provider/model/key can be used.
	TestConnection(ctx context.Context, provider, model, apiKey string) error
	// GetConcurrencyStats reports request-concurrency metrics.
	GetConcurrencyStats() ConcurrencyStats
	// GetQuestionBatchSize returns how many questions to request per call for provider.
	GetQuestionBatchSize(provider string) int
	// VarietyService exposes the service used to diversify generated questions.
	VarietyService() *VarietyService
	// TemplateManager exposes template rendering and example loading for prompts
	TemplateManager() *AITemplateManager
	// SupportsGrammarField reports whether the provider supports the grammar field
	SupportsGrammarField(provider string) bool
	// CallWithPrompt sends a raw prompt (and optional grammar) to the provider and returns the response
	CallWithPrompt(ctx context.Context, userConfig *UserAIConfig, prompt, grammar string) (string, error)
	// Shutdown stops the service and releases its resources.
	Shutdown(ctx context.Context) error
}
// ConcurrencyStats provides metrics about AI request concurrency
//
// It is a snapshot returned by GetConcurrencyStats; the JSON tags define its
// wire form.
type ConcurrencyStats struct {
	ActiveRequests  int            `json:"active_requests"`   // requests currently in flight
	MaxConcurrent   int            `json:"max_concurrent"`    // global concurrency limit
	QueuedRequests  int            `json:"queued_requests"`   // presumably requests waiting for a slot — confirm in GetConcurrencyStats
	TotalRequests   int64          `json:"total_requests"`    // requests processed since start
	UserActiveCount map[string]int `json:"user_active_count"` // username -> in-flight requests
	MaxPerUser      int            `json:"max_per_user"`      // per-user concurrency limit
}
// AIService provides AI-powered question generation using OpenAI-compatible APIs
//
// Construct it with NewAIService. Mutable state is split across three mutexes:
// concurrencyMu guards userRequestCount, statsMu guards totalRequests and
// activeRequests, and shutdownMu guards shutdownCtx.
type AIService struct {
	httpClient *http.Client // shared client, instrumented with otelhttp in NewAIService
	debug      bool         // copied from cfg.Server.Debug
	cfg        *config.Config
	// Template management
	templateManager *AITemplateManager
	// Variety service for question diversity
	varietyService *VarietyService
	// Concurrency control
	globalSemaphore chan struct{} // Limits total concurrent requests (capacity maxConcurrent)
	maxConcurrent   int           // Maximum concurrent requests globally
	maxPerUser      int           // Maximum concurrent requests per user
	// Per-user concurrency tracking
	userRequestCount map[string]int // Username -> active request count
	concurrencyMu    sync.RWMutex   // Protects user maps
	// Metrics
	totalRequests  int64        // Total requests processed
	activeRequests int          // Current active requests (polled by Shutdown)
	statsMu        sync.RWMutex // Protects stats
	// Observability
	logger *observability.Logger
	// Shutdown control
	// NOTE(review): the Go context docs discourage storing a Context in a
	// struct. In the visible code shutdownCtx is only checked by isShutdown, as
	// a "stopping" flag; it starts as context.Background() (never done).
	shutdownCtx context.Context
	shutdownMu  sync.RWMutex
}
// Schema validation counters
//
// Package-level, mutable state keyed by question type. SchemaValidationMu is
// declared alongside the maps; presumably every reader and writer must hold it
// (the accesses are not visible here — confirm at the call sites).
var (
	SchemaValidationFailures       = make(map[models.QuestionType]int)      // failure count per question type
	SchemaValidationFailureDetails = make(map[models.QuestionType][]string) // NEW: error details
	SchemaValidationMu             sync.Mutex
)
// extractItemsSchema parses a batch (array) JSON schema and returns its "items"
// sub-schema, re-encoded as JSON. Parse and encode errors are returned as-is.
// A schema without an "items" key yields an error.
func extractItemsSchema(batchSchema string) (result0 string, err error) {
	var parsed map[string]interface{}
	if err = json.Unmarshal([]byte(batchSchema), &parsed); err != nil {
		return "", err
	}

	items, found := parsed["items"]
	if !found {
		return "", contextutils.ErrorWithContextf("no items found in batch schema")
	}

	encoded, err := json.Marshal(items)
	if err != nil {
		return "", err
	}
	return string(encoded), nil
}
// ValidateQuestionSchema validates question against the single-item JSON schema
// for qType. It returns (true, nil) when the question conforms, and (false, err)
// otherwise, where err lists every schema violation.
//
// question may be any JSON-marshalable value. For a *models.Question, only its
// Content is validated. Nil inputs and unknown question types are rejected. The
// outcome is recorded on the span as "validation.result".
func (s *AIService) ValidateQuestionSchema(ctx context.Context, qType models.QuestionType, question interface{}) (result0 bool, err error) {
	_, span := observability.TraceAIFunction(ctx, "validate_question_schema",
		observability.AttributeQuestionType(qType),
	)
	defer observability.FinishSpan(span, &err)

	if question == nil {
		span.SetAttributes(attribute.String("validation.result", "nil_question"))
		return false, contextutils.ErrorWithContextf("question cannot be nil")
	}

	// Use the single-item schemas directly. The Batch* schemas are just
	// {"type":"array","items":<single schema>}, so extracting "items" from them
	// with a JSON round trip on every call only re-derived these constants.
	var itemSchema string
	switch qType {
	case models.Vocabulary:
		itemSchema = SingleVocabularyQuestionSchema
	case models.ReadingComprehension:
		itemSchema = SingleReadingComprehensionSchema
	case models.FillInBlank, models.QuestionAnswer:
		itemSchema = SingleQuestionSchema
	default:
		span.SetAttributes(attribute.String("validation.result", "unknown_type"))
		return false, contextutils.ErrorWithContextf("unknown question type: %v", qType)
	}

	// A *models.Question is validated by its Content only. The typed-nil
	// pointer case gets past the interface nil check above, so check it here.
	toValidate := question
	if q, ok := question.(*models.Question); ok {
		if q == nil {
			span.SetAttributes(attribute.String("validation.result", "nil_question_model"))
			return false, contextutils.ErrorWithContextf("question model is nil")
		}
		toValidate = q.Content
	}

	questionBytes, err := json.Marshal(toValidate)
	if err != nil {
		span.SetAttributes(attribute.String("validation.result", "marshal_error"), attribute.String("validation.error", err.Error()))
		return false, contextutils.WrapErrorf(err, "failed to marshal question for validation")
	}

	result, err := gojsonschema.Validate(
		gojsonschema.NewStringLoader(itemSchema),
		gojsonschema.NewBytesLoader(questionBytes),
	)
	if err != nil {
		span.SetAttributes(attribute.String("validation.result", "validate_error"), attribute.String("validation.error", err.Error()))
		return false, contextutils.WrapErrorf(err, "schema validation failed for question type %v", qType)
	}

	if !result.Valid() {
		errs := result.Errors()
		errorMessages := make([]string, 0, len(errs))
		for _, e := range errs {
			errorMessages = append(errorMessages, e.String())
		}
		span.SetAttributes(attribute.String("validation.result", "invalid"))
		return false, contextutils.ErrorWithContextf("question failed schema validation: %s", strings.Join(errorMessages, "; "))
	}

	span.SetAttributes(attribute.String("validation.result", "valid"))
	return true, nil
}
// NewAIService wires up an AIService from cfg. It loads the prompt templates,
// creates the variety service, builds an otelhttp-instrumented HTTP client with
// config.AIRequestTimeout, and sizes the global semaphore from
// cfg.Server.MaxAIConcurrent.
//
// It panics if the template manager cannot be created, after logging the error.
func NewAIService(cfg *config.Config, logger *observability.Logger) *AIService {
	templateManager, err := NewAITemplateManager()
	if err != nil {
		logger.Error(context.Background(), "Failed to create template manager", err, map[string]interface{}{})
		panic(err) // templates are required for every prompt; fatal at startup
	}

	variety := NewVarietyServiceWithLogger(cfg, logger)

	client := &http.Client{
		Timeout: config.AIRequestTimeout,
		Transport: otelhttp.NewTransport(http.DefaultTransport,
			otelhttp.WithSpanOptions(trace.WithSpanKind(trace.SpanKindClient)),
		),
	}

	globalLimit := cfg.Server.MaxAIConcurrent
	return &AIService{
		httpClient:       client,
		debug:            cfg.Server.Debug,
		cfg:              cfg,
		templateManager:  templateManager,
		varietyService:   variety,
		globalSemaphore:  make(chan struct{}, globalLimit),
		maxConcurrent:    globalLimit,
		maxPerUser:       cfg.Server.MaxAIPerUser,
		userRequestCount: make(map[string]int),
		shutdownCtx:      context.Background(),
		logger:           logger,
	}
}
// Shutdown marks the AI service as stopping, waits for in-flight requests to
// drain, and then releases resources.
//
// isShutdown reports true as soon as Shutdown starts. The active-request count
// is then polled every config.AIShutdownPollInterval, for at most
// config.AIShutdownTimeout or until ctx's deadline when ctx has one. If ctx is
// cancelled during the wait, ctx.Err() is returned and cleanup is skipped.
// Otherwise idle HTTP connections are closed, per-user counters are reset, and
// nil is returned. A warning is logged if requests were still active when the
// wait ended.
func (s *AIService) Shutdown(ctx context.Context) error {
	// Raise the shutdown flag immediately and release shutdownMu before waiting.
	// isShutdown takes the read lock, so holding the write lock for the whole
	// drain would block every isShutdown caller until the wait ended. Keeping
	// the flag down until the end would also let new work start during the
	// drain, assuming request paths consult isShutdown before starting.
	stopped, cancel := context.WithCancel(ctx)
	cancel()
	s.shutdownMu.Lock()
	s.shutdownCtx = stopped
	s.shutdownMu.Unlock()

	timeout := config.AIShutdownTimeout
	if deadline, ok := ctx.Deadline(); ok {
		timeout = time.Until(deadline)
	}

	ticker := time.NewTicker(config.AIShutdownPollInterval)
	defer ticker.Stop()
	for i := 0; i < int(timeout/config.AIShutdownPollInterval); i++ {
		s.statsMu.RLock()
		active := s.activeRequests
		s.statsMu.RUnlock()
		if active == 0 {
			break
		}
		select {
		case <-ticker.C:
		case <-ctx.Done():
			return ctx.Err()
		}
	}

	s.statsMu.RLock()
	remaining := s.activeRequests
	s.statsMu.RUnlock()
	if remaining > 0 {
		s.logger.Warn(ctx, "AI Service shutdown wait ended with requests still active", map[string]interface{}{
			"active_requests": remaining,
		})
	}

	if s.httpClient != nil {
		s.httpClient.CloseIdleConnections()
	}

	// NOTE(review): if requests are still active, their later decrements will
	// hit this fresh map.
	s.concurrencyMu.Lock()
	s.userRequestCount = make(map[string]int)
	s.concurrencyMu.Unlock()

	s.logger.Info(ctx, "AI Service shutdown completed")
	return nil
}
// isShutdown reports whether the service's current shutdown context has been
// cancelled, i.e. whether Shutdown has marked the service as stopping.
func (s *AIService) isShutdown() bool {
	s.shutdownMu.RLock()
	current := s.shutdownCtx
	s.shutdownMu.RUnlock()
	// Err is non-nil exactly when Done is closed.
	return current.Err() != nil
}
// OpenAIRequest represents a request to the OpenAI-compatible API
//
// Grammar carries a JSON schema only for providers that support it; it is
// omitted from the JSON body when empty. Stream is likewise omitted when false.
type OpenAIRequest struct {
	Model       string    `json:"model"`
	Messages    []Message `json:"messages"`
	Temperature float64   `json:"temperature"`
	MaxTokens   int       `json:"max_tokens"`
	Grammar     string    `json:"grammar,omitempty"`
	Stream      bool      `json:"stream,omitempty"`
}

// Message represents a chat message in the API request
type Message struct {
	Role    string `json:"role"` // chat role string sent as-is to the provider
	Content string `json:"content"`
}

// OpenAIResponse represents a response from the OpenAI-compatible API
//
// Error is non-nil when the provider reported an error in the body.
type OpenAIResponse struct {
	Choices []Choice  `json:"choices"`
	Error   *APIError `json:"error,omitempty"`
}

// Choice represents a choice in the API response
type Choice struct {
	Message Message `json:"message"`
}

// APIError represents an error response from the API
type APIError struct {
	Message string `json:"message"`
	Type    string `json:"type"`
}

// OpenAIStreamResponse represents a streaming response chunk from the OpenAI-compatible API
type OpenAIStreamResponse struct {
	Choices []StreamChoice `json:"choices"`
	Error   *APIError      `json:"error,omitempty"`
}

// StreamChoice represents a choice in the streaming API response
//
// FinishReason is a pointer so that a JSON null (stream still in progress)
// can be told apart from a present value.
type StreamChoice struct {
	Delta        StreamDelta `json:"delta"`
	FinishReason *string     `json:"finish_reason"`
}

// StreamDelta represents the delta content in a streaming response
type StreamDelta struct {
	Content string `json:"content"` // incremental text for this chunk
}
// getGrammarSchema returns the batch (array) JSON schema for questionType.
// Reading-comprehension and vocabulary questions have dedicated schemas.
// FillInBlank, QuestionAnswer and any unrecognized type share BatchQuestionsSchema.
func getGrammarSchema(questionType models.QuestionType) string {
	switch questionType {
	case models.ReadingComprehension:
		return BatchReadingComprehensionSchema
	case models.Vocabulary:
		return BatchVocabularyQuestionSchema
	default:
		return BatchQuestionsSchema
	}
}
// GetFixSchema returns the single-item JSON schema used by ai-fix for
// questionType. Unsupported types yield an error wrapping
// contextutils.ErrAIConfigInvalid.
func GetFixSchema(questionType models.QuestionType) (string, error) {
	var schema string
	switch questionType {
	case models.ReadingComprehension:
		schema = SingleReadingComprehensionSchema
	case models.Vocabulary:
		schema = SingleVocabularyQuestionSchema
	case models.FillInBlank, models.QuestionAnswer:
		schema = SingleQuestionSchema
	default:
		return "", contextutils.WrapErrorf(contextutils.ErrAIConfigInvalid, "no schema for question type: %v", questionType)
	}
	return schema, nil
}
// addJSONStructureGuidance appends JSON-structure instructions to prompt for
// providers that do not support the grammar field. The instructions come from
// JSONStructureGuidanceTemplate, rendered with the batch schema for
// questionType.
//
// If the template fails to render, the error is logged and a minimal inline
// instruction carrying the same schema is appended instead of panicking. This
// helper presumably runs while serving generation requests, where a panic
// would fail the request or crash the process.
func (s *AIService) addJSONStructureGuidance(prompt string, questionType models.QuestionType) string {
	schema := getGrammarSchema(questionType)

	guidance, err := s.templateManager.RenderTemplate(JSONStructureGuidanceTemplate, AITemplateData{
		SchemaForPrompt: schema,
	})
	if err != nil {
		s.logger.Error(context.Background(), "Failed to render JSON structure guidance template", err, map[string]interface{}{})
		guidance = "\n\nRespond ONLY with valid JSON that conforms to this JSON schema:\n" + schema
	}
	return prompt + guidance
}
// GenerateQuestion asks the user's configured provider for one question and
// parses the reply with parseQuestionResponse.
//
// The prompt always requests exactly one question. A copy of req with Count
// forced to 1 is used, so the caller's request is not modified and req.Count
// cannot make the provider produce questions that go unused. Providers that
// support the grammar field receive the batch JSON schema through it. Other
// providers get the structure described in the prompt text. Errors from the
// provider call or from parsing are returned unchanged.
func (s *AIService) GenerateQuestion(ctx context.Context, userConfig *UserAIConfig, req *models.AIQuestionGenRequest) (result0 *models.Question, err error) {
	ctx, span := observability.TraceAIFunction(ctx, "generate_question",
		attribute.String("user.username", userConfig.Username),
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
		observability.AttributeQuestionType(string(req.QuestionType)),
	)
	defer observability.FinishSpan(span, &err)

	// Use the batch prompt with count=1 for a single question.
	singleReq := *req
	singleReq.Count = 1

	var prompt, grammar string
	if s.supportsGrammarField(userConfig.Provider) {
		prompt = s.buildBatchQuestionPrompt(ctx, &singleReq, nil)
		grammar = getGrammarSchema(req.QuestionType)
	} else {
		// No grammar field: the JSON structure guidance is embedded in the prompt.
		prompt = s.buildBatchQuestionPromptWithJSONStructure(ctx, &singleReq, nil)
	}

	response, err := s.callOpenAI(ctx, userConfig, prompt, grammar)
	if err != nil {
		return nil, err
	}

	question, err := s.parseQuestionResponse(ctx, response, req.Language, req.Level, req.QuestionType, userConfig.Provider)
	if err != nil {
		return nil, err
	}
	return question, nil
}
// GenerateQuestions requests req.Count questions from the user's provider in a
// single call and parses the reply with parseQuestionsResponse.
// Grammar-capable providers receive the batch schema in the grammar field;
// the others get the JSON structure described inside the prompt. On any error,
// the question slice is nil.
func (s *AIService) GenerateQuestions(ctx context.Context, userConfig *UserAIConfig, req *models.AIQuestionGenRequest) (result0 []*models.Question, err error) {
	ctx, span := observability.TraceAIFunction(ctx, "generate_questions",
		attribute.String("user.username", userConfig.Username),
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
		observability.AttributeQuestionType(string(req.QuestionType)),
		observability.AttributeLimit(req.Count),
	)
	defer observability.FinishSpan(span, &err)

	prompt, grammar := "", ""
	if s.supportsGrammarField(userConfig.Provider) {
		prompt = s.buildBatchQuestionPrompt(ctx, req, nil)
		grammar = getGrammarSchema(req.QuestionType)
	} else {
		prompt = s.buildBatchQuestionPromptWithJSONStructure(ctx, req, nil)
	}

	response, err := s.callOpenAI(ctx, userConfig, prompt, grammar)
	if err != nil {
		return nil, err
	}

	parsed, err := s.parseQuestionsResponse(ctx, response, req.Language, req.Level, req.QuestionType, userConfig.Provider)
	if err != nil {
		return nil, err
	}
	return parsed, nil
}
// GenerateQuestionsStream generates questions and sends each one on progress
// as it becomes available, using the provided variety elements for every
// batch.
//
// Generation runs under the per-user concurrency control, in batches sized by
// getQuestionBatchSize for the provider. progress is closed when this function
// returns, whether generation succeeded, failed, or was rejected.
func (s *AIService) GenerateQuestionsStream(ctx context.Context, userConfig *UserAIConfig, req *models.AIQuestionGenRequest, progress chan<- *models.Question, variety *VarietyElements) (err error) {
	ctx, span := observability.TraceAIFunction(ctx, "generate_questions_stream",
		attribute.String("user.username", userConfig.Username),
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
		observability.AttributeQuestionType(string(req.QuestionType)),
		observability.AttributeLimit(req.Count),
	)
	defer observability.FinishSpan(span, &err)
	// This function is the sender, so it owns closing the channel. Deferred
	// after FinishSpan, so the close runs first.
	defer close(progress)

	generate := func() error {
		return s.generateQuestionsInBatchesWithVariety(ctx, userConfig, req, progress, s.getQuestionBatchSize(userConfig.Provider), variety)
	}
	return s.withConcurrencyControl(ctx, userConfig.Username, generate)
}
// generateQuestionsInBatchesWithVariety generates req.Count questions in
// batches of at most batchSize and sends each question on progress once its
// batch has been parsed. The same variety elements are used for every batch.
// Each streamed question's "question" text (when it is a string) is appended
// to a local copy of req.RecentQuestionHistory, so later batches see it. The
// caller's slice is not modified.
//
// It returns an error in four cases:
//   - ctx is cancelled, either between batches or while waiting to send;
//   - a batch fails;
//   - a batch yields no questions. Retrying would never make progress, so
//     this is reported as ErrAIResponseInvalid instead of looping forever.
//
// A batchSize below 1 is treated as 1. progress is not closed here.
func (s *AIService) generateQuestionsInBatchesWithVariety(ctx context.Context, userConfig *UserAIConfig, req *models.AIQuestionGenRequest, progress chan<- *models.Question, batchSize int, variety *VarietyElements) (err error) {
	ctx, span := observability.TraceAIFunction(ctx, "generate_questions_in_batches_with_variety",
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
		observability.AttributeQuestionType(req.QuestionType),
		observability.AttributeLanguage(req.Language),
		observability.AttributeLevel(req.Level),
		attribute.Int("batch_size", batchSize),
		attribute.Int("total_count", req.Count),
		attribute.Bool("variety.enabled", variety != nil),
	)
	defer observability.FinishSpan(span, &err)

	// A non-positive batch size would request zero questions per call and
	// never reduce the remaining count.
	if batchSize < 1 {
		batchSize = 1
	}

	// Local copy of history, extended as questions are generated.
	localHistory := make([]string, len(req.RecentQuestionHistory))
	copy(localHistory, req.RecentQuestionHistory)

	remaining := req.Count
	generated := 0
	for remaining > 0 {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return ctxErr
		}
		currentBatchSize := min(remaining, batchSize)
		batchReq := &models.AIQuestionGenRequest{
			Language:              req.Language,
			Level:                 req.Level,
			QuestionType:          req.QuestionType,
			Count:                 currentBatchSize,
			RecentQuestionHistory: localHistory,
		}
		questions, batchErr := s.generateQuestionsWithVariety(ctx, userConfig, batchReq, variety)
		if batchErr != nil {
			return contextutils.WrapErrorf(batchErr, "failed to generate batch of %d questions for user %s", currentBatchSize, userConfig.Username)
		}
		if len(questions) == 0 {
			return contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "AI returned no questions for a batch of %d for user %s", currentBatchSize, userConfig.Username)
		}
		for _, question := range questions {
			if qStr, ok := question.Content["question"].(string); ok {
				localHistory = append(localHistory, qStr)
			}
			// Do not block forever if the consumer has stopped reading.
			select {
			case progress <- question:
				generated++
			case <-ctx.Done():
				return ctx.Err()
			}
		}
		remaining -= len(questions)
	}
	span.SetAttributes(attribute.Int("generated_count", generated))
	return nil
}
// generateQuestionsWithVariety makes one AI call for req.Count questions and
// returns the parsed questions. variety may be nil; it is passed to the prompt
// builder.
//
// Providers with grammar support get the question type's schema as a grammar
// constraint. Others get JSON structure guidance embedded in the prompt and an
// empty grammar. A returned error is recorded on the tracing span.
func (s *AIService) generateQuestionsWithVariety(ctx context.Context, userConfig *UserAIConfig, req *models.AIQuestionGenRequest, variety *VarietyElements) (result0 []*models.Question, err error) {
	ctx, span := observability.TraceAIFunction(ctx, "generate_questions_with_variety",
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
		observability.AttributeQuestionType(req.QuestionType),
		observability.AttributeLanguage(req.Language),
		observability.AttributeLevel(req.Level),
		attribute.Int("count", req.Count),
		attribute.Bool("variety.enabled", variety != nil),
	)
	defer func() {
		if err == nil {
			span.End()
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
		span.End()
	}()

	var prompt, grammar string
	if !s.supportsGrammarField(userConfig.Provider) {
		prompt = s.buildBatchQuestionPromptWithJSONStructure(ctx, req, variety)
	} else {
		prompt = s.buildBatchQuestionPrompt(ctx, req, variety)
		grammar = getGrammarSchema(req.QuestionType)
	}

	raw, err := s.callOpenAI(ctx, userConfig, prompt, grammar)
	if err != nil {
		return nil, err
	}
	parsed, err := s.parseQuestionsResponse(ctx, raw, req.Language, req.Level, req.QuestionType, userConfig.Provider)
	if err != nil {
		return nil, err
	}
	return parsed, nil
}
// GenerateChatResponse generates a non-streaming chat reply for req.
//
// The prompt is built from the chat template and sent without a grammar
// constraint, under the per-user concurrency control. If the concurrency
// wrapper reports an error, "" and that error are returned. Otherwise the
// reply and the AI call's error are returned.
func (s *AIService) GenerateChatResponse(ctx context.Context, userConfig *UserAIConfig, req *models.AIChatRequest) (result0 string, err error) {
	ctx, span := observability.TraceAIFunction(ctx, "generate_chat_response",
		attribute.String("user.username", userConfig.Username),
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
	)
	defer observability.FinishSpan(span, &err)

	var reply string
	var callErr error
	if err = s.withConcurrencyControl(ctx, userConfig.Username, func() error {
		// Open-ended chat: no grammar constraint.
		reply, callErr = s.callOpenAI(ctx, userConfig, s.buildChatPrompt(req), "")
		return callErr
	}); err != nil {
		return "", err
	}
	return reply, callErr
}
// GenerateChatResponseStream streams a chat reply for req into chunks.
//
// The prompt is built from the chat template and sent without a grammar
// constraint, under the per-user concurrency control. chunks is deliberately
// left open; the caller owns it and closes it, which avoids racing with other
// senders.
func (s *AIService) GenerateChatResponseStream(ctx context.Context, userConfig *UserAIConfig, req *models.AIChatRequest, chunks chan<- string) (err error) {
	ctx, span := observability.TraceAIFunction(ctx, "generate_chat_response_stream",
		attribute.String("user.username", userConfig.Username),
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
	)
	defer observability.FinishSpan(span, &err)

	streamChat := func() error {
		return s.callOpenAIStream(ctx, userConfig, s.buildChatPrompt(req), "", chunks)
	}
	return s.withConcurrencyControl(ctx, userConfig.Username, streamChat)
}
// TestConnection verifies that provider/model/apiKey can complete a real AI
// request.
//
// It rejects an empty provider or model, and a missing API key for any
// provider other than "ollama", with ErrAIConfigInvalid. It then sends a fixed
// prompt through callOpenAI with a config.AIRequestTimeout deadline. The test
// fails if the call errors, or (with ErrAIResponseInvalid) if the reply is
// empty or shorter than 3 bytes. The outcome is recorded on the span as
// "test.result".
func (s *AIService) TestConnection(ctx context.Context, provider, model, apiKey string) (err error) {
	// Keep the span-carrying context so the callOpenAI span and logs below
	// are parented under this test span.
	ctx, span := observability.TraceAIFunction(ctx, "test_connection",
		attribute.String("ai.provider", provider),
		attribute.String("ai.model", model),
	)
	defer observability.FinishSpan(span, &err)

	if provider == "" {
		span.SetAttributes(attribute.String("test.result", "empty_provider"))
		return contextutils.WrapError(contextutils.ErrAIConfigInvalid, "provider is required for testing connection")
	}
	if model == "" {
		span.SetAttributes(attribute.String("test.result", "empty_model"))
		return contextutils.WrapError(contextutils.ErrAIConfigInvalid, "model is required for testing connection")
	}
	s.logger.Debug(ctx, "TestConnection called", map[string]interface{}{
		"provider": provider,
		"model":    model,
		"apiKey":   contextutils.MaskAPIKey(apiKey),
	})
	// Every provider except Ollama (typically local) needs an API key.
	if provider != "ollama" && apiKey == "" {
		span.SetAttributes(attribute.String("test.result", "missing_api_key"), attribute.String("provider", provider))
		return contextutils.WrapErrorf(contextutils.ErrAIConfigInvalid, "API key is required for testing connection with provider '%s'", provider)
	}

	userConfig := &UserAIConfig{
		Provider: provider,
		Model:    model,
		APIKey:   apiKey,
		Username: "test-user",
	}
	s.logger.Debug(ctx, "Created userConfig", map[string]interface{}{
		"provider": userConfig.Provider,
		"model":    userConfig.Model,
	})

	testPrompt := "Respond with exactly the word 'SUCCESS' and nothing else."

	// Bound the test call so a hung provider cannot stall the caller.
	testCtx, cancel := context.WithTimeout(ctx, config.AIRequestTimeout)
	defer cancel()

	response, err := s.callOpenAI(testCtx, userConfig, testPrompt, "")
	if err != nil {
		span.SetAttributes(attribute.String("test.result", "call_failed"), attribute.String("error", err.Error()))
		return contextutils.WrapErrorf(err, "connection test failed for provider '%s' with model '%s'", provider, model)
	}
	if response == "" {
		span.SetAttributes(attribute.String("test.result", "empty_response"))
		return contextutils.WrapError(contextutils.ErrAIResponseInvalid, "connection test failed: received empty response from AI service")
	}
	// Anything shorter than 3 bytes is treated as not a meaningful reply.
	if len(response) < 3 {
		span.SetAttributes(attribute.String("test.result", "response_too_short"), attribute.Int("response_length", len(response)))
		return contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "connection test failed: response too short (%d characters)", len(response))
	}
	s.logger.Info(ctx, "TestConnection successful", map[string]interface{}{
		"provider":        provider,
		"response_length": len(response),
	})
	span.SetAttributes(attribute.String("test.result", "success"), attribute.Int("response_length", len(response)))
	return nil
}
// buildBatchQuestionPromptWithJSONStructure builds the batch question prompt
// for req (with optional variety elements) and appends the rendered JSON
// structure guidance for req.QuestionType. It is used when the provider cannot
// take a grammar field.
func (s *AIService) buildBatchQuestionPromptWithJSONStructure(ctx context.Context, req *models.AIQuestionGenRequest, variety *VarietyElements) string {
	return s.addJSONStructureGuidance(s.buildBatchQuestionPrompt(ctx, req, variety), req.QuestionType)
}
// buildBatchQuestionPrompt renders BatchQuestionPromptTemplate for req.
//
// The template data carries:
//   - the question type's grammar schema;
//   - language, level, type and count;
//   - the recent question history;
//   - when variety is non-nil, its topic, grammar focus, vocabulary domain,
//     scenario, style, difficulty and time-context elements;
//   - the example for the question type, if one can be loaded.
//
// Priority data is handled by the worker and is not passed here. A template
// rendering failure is logged and then panics.
func (s *AIService) buildBatchQuestionPrompt(ctx context.Context, req *models.AIQuestionGenRequest, variety *VarietyElements) string {
	_, span := observability.TraceAIFunction(ctx, "build_batch_question_prompt",
		observability.AttributeQuestionType(req.QuestionType),
		observability.AttributeLanguage(req.Language),
		observability.AttributeLevel(req.Level),
		attribute.Int("count", req.Count),
		attribute.Bool("variety.enabled", variety != nil),
	)
	defer span.End()

	data := AITemplateData{
		SchemaForPrompt:       getGrammarSchema(req.QuestionType),
		Language:              req.Language,
		Level:                 req.Level,
		QuestionType:          string(req.QuestionType),
		Count:                 req.Count,
		RecentQuestionHistory: req.RecentQuestionHistory,
	}
	if v := variety; v != nil {
		data.TopicCategory, data.GrammarFocus, data.VocabularyDomain = v.TopicCategory, v.GrammarFocus, v.VocabularyDomain
		data.Scenario, data.StyleModifier = v.Scenario, v.StyleModifier
		data.DifficultyModifier, data.TimeContext = v.DifficultyModifier, v.TimeContext
	}

	// Examples are optional: a load failure just leaves ExampleContent unset.
	example, exampleErr := s.templateManager.LoadExample(string(req.QuestionType))
	if exampleErr == nil {
		data.ExampleContent = example
	}

	rendered, renderErr := s.templateManager.RenderTemplate(BatchQuestionPromptTemplate, data)
	if renderErr != nil {
		s.logger.Error(ctx, "Failed to render batch question prompt template", renderErr, map[string]interface{}{})
		panic(renderErr)
	}
	return rendered
}
// buildChatPrompt renders ChatPromptTemplate for a chat request.
//
// The template data carries the question context (language, level, type,
// passage, question, options, correctness), the conversation history converted
// to ChatMessage values, and the user's new message. A template rendering
// failure is logged and then panics.
func (s *AIService) buildChatPrompt(req *models.AIChatRequest) string {
	var history []ChatMessage
	for _, turn := range req.ConversationHistory {
		history = append(history, ChatMessage{Role: string(turn.Role), Content: turn.Content})
	}

	rendered, renderErr := s.templateManager.RenderTemplate(ChatPromptTemplate, AITemplateData{
		Language:            req.Language,
		Level:               req.Level,
		QuestionType:        string(req.QuestionType),
		Passage:             req.Passage,
		Question:            req.Question,
		Options:             req.Options,
		IsCorrect:           req.IsCorrect,
		ConversationHistory: history,
		UserMessage:         req.UserMessage,
	})
	if renderErr != nil {
		s.logger.Error(context.Background(), "Failed to render chat prompt template", renderErr, map[string]interface{}{})
		panic(renderErr)
	}
	return rendered
}
// getMaxTokensForModel returns the configured MaxTokens for model under the
// first provider in s.cfg.Providers whose Code equals provider.
//
// It returns 4000 in these cases:
//   - no provider matches;
//   - the first matching provider has no matching model;
//   - the matched model's MaxTokens is not positive.
//
// Later providers with the same code are not consulted.
func (s *AIService) getMaxTokensForModel(provider, model string) int {
	const defaultMaxTokens = 4000
	if s.cfg.Providers == nil {
		return defaultMaxTokens
	}
	for _, p := range s.cfg.Providers {
		if p.Code != provider {
			continue
		}
		for _, m := range p.Models {
			if m.Code != model {
				continue
			}
			if m.MaxTokens > 0 {
				return m.MaxTokens
			}
			return defaultMaxTokens
		}
		return defaultMaxTokens
	}
	return defaultMaxTokens
}
// callOpenAI sends prompt as a single user message to the provider's
// OpenAI-compatible <baseURL>/chat/completions endpoint and returns the
// content of the first choice.
//
// Request details:
//   - The base URL is the first non-empty URL in s.cfg.Providers whose Code
//     matches userConfig.Provider.
//   - grammar is sent in the Grammar field only when the provider supports it
//     (supportsGrammarField) and grammar is non-empty.
//   - A bearer Authorization header is sent only when userConfig.APIKey is set.
//   - max_tokens comes from getMaxTokensForModel and temperature is fixed at 0.7.
//
// Errors are wrapped with contextutils sentinels:
//   - ErrAIConfigInvalid for a nil or incomplete config, an empty prompt, or no
//     configured URL;
//   - ErrAIRequestFailed for non-200 statuses and API-reported errors;
//   - ErrAIResponseInvalid for undecodable, choiceless, or empty responses.
//
// Any returned error is recorded on the tracing span.
func (s *AIService) callOpenAI(ctx context.Context, userConfig *UserAIConfig, prompt, grammar string) (result0 string, err error) {
	if userConfig == nil {
		return "", contextutils.WrapError(contextutils.ErrAIConfigInvalid, "userConfig is required")
	}
	// Keep the span-carrying context so the HTTP request (and any
	// context-aware client instrumentation or logging) is attributed to this
	// span rather than to its parent.
	ctx, span := observability.TraceAIFunction(ctx, "call_openai",
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
		attribute.String("ai.username", userConfig.Username),
		attribute.Int("prompt.length", len(prompt)),
		attribute.Bool("grammar.enabled", grammar != ""),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Validate input parameters
	if userConfig.Provider == "" {
		span.SetAttributes(attribute.String("call.result", "empty_provider"))
		return "", contextutils.WrapError(contextutils.ErrAIConfigInvalid, "provider is required")
	}
	if userConfig.Model == "" {
		span.SetAttributes(attribute.String("call.result", "empty_model"))
		return "", contextutils.WrapError(contextutils.ErrAIConfigInvalid, "model is required")
	}
	if prompt == "" {
		span.SetAttributes(attribute.String("call.result", "empty_prompt"))
		return "", contextutils.WrapError(contextutils.ErrAIConfigInvalid, "prompt cannot be empty")
	}

	apiURL := ""
	model := userConfig.Model
	apiKey := userConfig.APIKey
	// Use the first configured provider entry with a matching code and a URL.
	if s.cfg.Providers != nil {
		for _, providerConfig := range s.cfg.Providers {
			if providerConfig.Code == userConfig.Provider && providerConfig.URL != "" {
				apiURL = providerConfig.URL
				break
			}
		}
	}
	if apiURL == "" {
		span.SetAttributes(attribute.String("call.result", "no_url_configured"), attribute.String("provider", userConfig.Provider))
		return "", contextutils.WrapErrorf(contextutils.ErrAIConfigInvalid, "no base URL configured for provider '%s'", userConfig.Provider)
	}

	userPrefix := ""
	if userConfig.Username != "" {
		userPrefix = fmt.Sprintf("[user=%s] ", userConfig.Username)
	}
	s.logger.Debug(ctx, "Starting AI request", map[string]interface{}{
		"user_prefix": userPrefix,
		"url":         apiURL + "/chat/completions",
		"model":       model,
		"provider":    userConfig.Provider,
	})

	// A single user message; when supported and provided, the grammar field
	// constrains the output structure.
	messages := []Message{{Role: "user", Content: prompt}}
	supportsGrammar := s.supportsGrammarField(userConfig.Provider)
	reqBody := OpenAIRequest{
		Model:       model,
		Messages:    messages,
		Temperature: 0.7,
		MaxTokens:   s.getMaxTokensForModel(userConfig.Provider, userConfig.Model),
	}
	// Only include grammar field if the provider supports it
	if supportsGrammar && grammar != "" {
		reqBody.Grammar = grammar
	}
	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		s.logger.Error(ctx, "Failed to marshal AI request", err, map[string]interface{}{
			"user_prefix": userPrefix,
		})
		span.SetAttributes(attribute.String("call.result", "marshal_failed"), attribute.String("error", err.Error()))
		return "", contextutils.WrapErrorf(err, "failed to marshal request body")
	}

	s.logger.Debug(ctx, "Making AI HTTP request", map[string]interface{}{
		"user_prefix": userPrefix,
		"url":         apiURL + "/chat/completions",
	})
	req, err := http.NewRequestWithContext(ctx, "POST", apiURL+"/chat/completions", bytes.NewBuffer(jsonData))
	if err != nil {
		s.logger.Error(ctx, "Failed to create AI HTTP request", err, map[string]interface{}{
			"user_prefix": userPrefix,
		})
		span.SetAttributes(attribute.String("call.result", "request_creation_failed"), attribute.String("error", err.Error()))
		return "", contextutils.WrapErrorf(err, "failed to create HTTP request")
	}
	req.Header.Set("Content-Type", "application/json")
	if apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+apiKey)
		s.logger.Debug(ctx, "Using API key authentication", map[string]interface{}{
			"user_prefix": userPrefix,
		})
	} else {
		s.logger.Debug(ctx, "No API key provided, using anonymous access", map[string]interface{}{
			"user_prefix": userPrefix,
		})
	}

	startTime := time.Now()
	// req already carries ctx from NewRequestWithContext; no WithContext copy needed.
	resp, err := s.httpClient.Do(req)
	duration := time.Since(startTime)
	if err != nil {
		s.logger.Error(ctx, "AI HTTP request failed", err, map[string]interface{}{
			"user_prefix": userPrefix,
			"duration":    duration.String(),
		})
		span.SetAttributes(attribute.String("call.result", "http_request_failed"), attribute.String("error", err.Error()), attribute.String("duration", duration.String()))
		return "", contextutils.WrapErrorf(err, "HTTP request failed after %v", duration)
	}
	defer func() {
		if err := resp.Body.Close(); err != nil {
			s.logger.Warn(ctx, "Failed to close response body", map[string]interface{}{
				"error": err.Error(),
			})
		}
	}()
	s.logger.Info(ctx, "AI Service HTTP request completed", map[string]interface{}{
		"user_prefix": userPrefix,
		"duration":    duration.String(),
		"status_code": resp.StatusCode,
	})

	body, err := io.ReadAll(resp.Body)
	if err != nil {
		span.SetAttributes(attribute.String("call.result", "body_read_failed"), attribute.String("error", err.Error()))
		return "", contextutils.WrapErrorf(err, "failed to read response body")
	}
	if resp.StatusCode != http.StatusOK {
		span.SetAttributes(attribute.String("call.result", "http_error"), attribute.Int("status_code", resp.StatusCode), attribute.String("body", string(body)))
		return "", contextutils.WrapErrorf(contextutils.ErrAIRequestFailed, "API request failed with status %d to %s: %s", resp.StatusCode, apiURL+"/chat/completions", string(body))
	}

	var openAIResp OpenAIResponse
	if err := json.Unmarshal(body, &openAIResp); err != nil {
		span.SetAttributes(attribute.String("call.result", "json_unmarshal_failed"), attribute.String("error", err.Error()), attribute.String("body", string(body)))
		return "", contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "failed to parse AI response as JSON: %w. Raw Response: %s", err, string(body))
	}
	if openAIResp.Error != nil {
		span.SetAttributes(attribute.String("call.result", "api_error"), attribute.String("error_message", openAIResp.Error.Message), attribute.String("error_type", openAIResp.Error.Type))
		return "", contextutils.WrapErrorf(contextutils.ErrAIRequestFailed, "OpenAI API error: %s", openAIResp.Error.Message)
	}
	if len(openAIResp.Choices) == 0 {
		span.SetAttributes(attribute.String("call.result", "no_choices"))
		return "", contextutils.WrapError(contextutils.ErrAIResponseInvalid, "no response from OpenAI")
	}
	content := openAIResp.Choices[0].Message.Content
	if content == "" {
		span.SetAttributes(attribute.String("call.result", "empty_content"))
		return "", contextutils.WrapError(contextutils.ErrAIResponseInvalid, "AI returned empty content")
	}
	span.SetAttributes(attribute.String("call.result", "success"), attribute.Int("content_length", len(content)), attribute.String("duration", duration.String()))
	return content, nil
}
// callOpenAIStream sends prompt as a single user message to the provider's
// OpenAI-compatible <baseURL>/chat/completions endpoint with streaming enabled
// and forwards the reply to chunks as it arrives.
//
// The body is read as a Server-Sent Events stream:
//   - blank lines and ":" comment lines are skipped;
//   - each "data:" line (with or without the optional space after the colon)
//     is decoded as an OpenAIStreamResponse, and undecodable chunks are logged
//     and skipped;
//   - non-empty delta content is passed through filterThinkingContent and sent
//     only if something remains;
//   - reading stops at "[DONE]" or when the first choice reports a finish
//     reason.
//
// chunks is never closed here. Errors are wrapped with contextutils sentinels:
//   - ErrAIConfigInvalid for bad input or no configured URL;
//   - ErrAIRequestFailed for transport, HTTP, API, or stream-read failures.
//
// Cancellation of ctx while sending a chunk returns ctx.Err(). Any returned
// error is recorded on the tracing span.
func (s *AIService) callOpenAIStream(ctx context.Context, userConfig *UserAIConfig, prompt, grammar string, chunks chan<- string) (err error) {
	if userConfig == nil {
		return contextutils.WrapError(contextutils.ErrAIConfigInvalid, "userConfig is required")
	}
	// Keep the span-carrying context so the HTTP request is attributed to this span.
	ctx, span := observability.TraceAIFunction(ctx, "call_openai_stream",
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
		attribute.String("ai.username", userConfig.Username),
		attribute.Int("prompt.length", len(prompt)),
		attribute.Bool("grammar.enabled", grammar != ""),
	)
	// Record any returned error on the span before ending it, as callOpenAI does.
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Validate input parameters
	if userConfig.Provider == "" {
		span.SetAttributes(attribute.String("stream.result", "empty_provider"))
		return contextutils.WrapError(contextutils.ErrAIConfigInvalid, "provider is required")
	}
	if userConfig.Model == "" {
		span.SetAttributes(attribute.String("stream.result", "empty_model"))
		return contextutils.WrapError(contextutils.ErrAIConfigInvalid, "model is required")
	}
	if prompt == "" {
		span.SetAttributes(attribute.String("stream.result", "empty_prompt"))
		return contextutils.WrapError(contextutils.ErrAIConfigInvalid, "prompt cannot be empty")
	}
	if chunks == nil {
		span.SetAttributes(attribute.String("stream.result", "nil_chunks_channel"))
		return contextutils.WrapError(contextutils.ErrAIConfigInvalid, "chunks channel is required")
	}

	apiURL := ""
	model := userConfig.Model
	apiKey := userConfig.APIKey
	// Use the first configured provider entry with a matching code and a URL.
	if s.cfg.Providers != nil {
		for _, providerConfig := range s.cfg.Providers {
			if providerConfig.Code == userConfig.Provider && providerConfig.URL != "" {
				apiURL = providerConfig.URL
				break
			}
		}
	}
	if apiURL == "" {
		span.SetAttributes(attribute.String("stream.result", "no_url_configured"), attribute.String("provider", userConfig.Provider))
		return contextutils.WrapErrorf(contextutils.ErrAIConfigInvalid, "no base URL configured for provider '%s'", userConfig.Provider)
	}

	userPrefix := ""
	if userConfig.Username != "" {
		userPrefix = fmt.Sprintf("[user=%s] ", userConfig.Username)
	}
	s.logger.Info(ctx, "AI Service Starting streaming request", map[string]interface{}{
		"user_prefix": userPrefix,
		"api_url":     apiURL + "/chat/completions",
		"model":       model,
		"provider":    userConfig.Provider,
	})

	// A single user message; when supported and provided, the grammar field
	// constrains the output structure.
	messages := []Message{{Role: "user", Content: prompt}}
	supportsGrammar := s.supportsGrammarField(userConfig.Provider)
	reqBody := OpenAIRequest{
		Model:       model,
		Messages:    messages,
		Temperature: 0.7,
		MaxTokens:   s.getMaxTokensForModel(userConfig.Provider, userConfig.Model),
		Stream:      true, // Enable streaming
	}
	// Only include grammar field if the provider supports it
	if supportsGrammar && grammar != "" {
		reqBody.Grammar = grammar
	}
	jsonData, err := json.Marshal(reqBody)
	if err != nil {
		s.logger.Error(ctx, "Failed to marshal request", err, map[string]interface{}{
			"user_prefix": userPrefix,
		})
		span.SetAttributes(attribute.String("stream.result", "marshal_failed"), attribute.String("error", err.Error()))
		return contextutils.WrapErrorf(err, "failed to marshal streaming request body")
	}

	s.logger.Info(ctx, "AI Service Making streaming HTTP request", map[string]interface{}{
		"user_prefix": userPrefix,
		"api_url":     apiURL + "/chat/completions",
	})
	req, err := http.NewRequestWithContext(ctx, "POST", apiURL+"/chat/completions", bytes.NewBuffer(jsonData))
	if err != nil {
		s.logger.Error(ctx, "Failed to create HTTP request", err, map[string]interface{}{
			"user_prefix": userPrefix,
		})
		span.SetAttributes(attribute.String("stream.result", "request_creation_failed"), attribute.String("error", err.Error()))
		return contextutils.WrapErrorf(err, "failed to create streaming HTTP request")
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Accept", "text/event-stream")
	req.Header.Set("Cache-Control", "no-cache")
	if apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+apiKey)
		s.logger.Info(ctx, "AI Service Using API key authentication", map[string]interface{}{
			"user_prefix": userPrefix,
		})
	} else {
		s.logger.Info(ctx, "AI Service No API key provided, using anonymous access", map[string]interface{}{
			"user_prefix": userPrefix,
		})
	}

	startTime := time.Now()
	// req already carries ctx from NewRequestWithContext; no WithContext copy needed.
	resp, err := s.httpClient.Do(req)
	if err != nil {
		s.logger.Error(ctx, "HTTP request failed", err, map[string]interface{}{
			"user_prefix": userPrefix,
		})
		span.SetAttributes(attribute.String("stream.result", "http_request_failed"), attribute.String("error", err.Error()))
		return contextutils.WrapErrorf(contextutils.ErrAIRequestFailed, "http client error: %w", err)
	}
	defer func() {
		if err := resp.Body.Close(); err != nil {
			s.logger.Warn(ctx, "Failed to close response body", map[string]interface{}{
				"error": err.Error(),
			})
		}
	}()
	if resp.StatusCode != http.StatusOK {
		// Best effort: the body is only used to enrich the error message.
		body, _ := io.ReadAll(resp.Body)
		span.SetAttributes(attribute.String("stream.result", "http_error"), attribute.Int("status_code", resp.StatusCode), attribute.String("body", string(body)))
		return contextutils.WrapErrorf(contextutils.ErrAIRequestFailed, "API request failed with status %d to %s: %s", resp.StatusCode, apiURL+"/chat/completions", string(body))
	}
	s.logger.Info(ctx, "AI Service Streaming response started", map[string]interface{}{
		"user_prefix": userPrefix,
		"duration":    time.Since(startTime).String(),
	})

	scanner := bufio.NewScanner(resp.Body)
	// A single SSE line can exceed bufio.Scanner's default 64 KiB token limit;
	// allow lines up to 1 MiB before failing with bufio.ErrTooLong.
	scanner.Buffer(make([]byte, 0, 64*1024), 1024*1024)
	var chunkCount int
	var totalContentLength int
	for scanner.Scan() {
		line := scanner.Text()
		// Skip empty lines and comments
		if line == "" || strings.HasPrefix(line, ":") {
			continue
		}
		// Only "data" fields carry payloads. In SSE the space after the colon
		// is optional, so accept both "data: x" and "data:x".
		data, isData := strings.CutPrefix(line, "data:")
		if !isData {
			continue
		}
		data = strings.TrimPrefix(data, " ")
		// Check for end of stream
		if data == "[DONE]" {
			break
		}
		var streamResp OpenAIStreamResponse
		if err := json.Unmarshal([]byte(data), &streamResp); err != nil {
			s.logger.Warn(ctx, "AI Service WARNING: Failed to parse streaming chunk", map[string]interface{}{
				"error": err.Error(),
				"data":  data,
			})
			span.SetAttributes(attribute.String("stream.result", "chunk_parse_failed"), attribute.String("error", err.Error()), attribute.String("data", data))
			continue
		}
		if streamResp.Error != nil {
			span.SetAttributes(attribute.String("stream.result", "api_streaming_error"), attribute.String("error_message", streamResp.Error.Message), attribute.String("error_type", streamResp.Error.Type))
			return contextutils.WrapErrorf(contextutils.ErrAIRequestFailed, "OpenAI API streaming error: %s", streamResp.Error.Message)
		}
		if len(streamResp.Choices) > 0 && streamResp.Choices[0].Delta.Content != "" {
			content := streamResp.Choices[0].Delta.Content
			totalContentLength += len(content)
			// Filter out thinking content for thinking models
			filteredContent := s.filterThinkingContent(content, model)
			if filteredContent != "" {
				select {
				case chunks <- filteredContent:
					chunkCount++
				case <-ctx.Done():
					span.SetAttributes(attribute.String("stream.result", "context_cancelled"))
					return ctx.Err()
				}
			}
		}
		// Check if streaming is finished
		if len(streamResp.Choices) > 0 && streamResp.Choices[0].FinishReason != nil {
			break
		}
	}
	if err := scanner.Err(); err != nil {
		span.SetAttributes(attribute.String("stream.result", "scanner_error"), attribute.String("error", err.Error()))
		return contextutils.WrapErrorf(contextutils.ErrAIRequestFailed, "error reading streaming response: %w", err)
	}

	s.logger.Info(ctx, "AI Service Streaming response completed", map[string]interface{}{
		"user_prefix":          userPrefix,
		"duration":             time.Since(startTime).String(),
		"chunk_count":          chunkCount,
		"total_content_length": totalContentLength,
	})
	span.SetAttributes(attribute.String("stream.result", "success"), attribute.Int("chunk_count", chunkCount), attribute.Int("total_content_length", totalContentLength), attribute.String("duration", time.Since(startTime).String()))
	return nil
}
// filterThinkingContent applies per-chunk heuristics that suppress reasoning
// text from models that isThinkingModel reports as thinking models. Content
// for any other model is returned unchanged.
//
// For thinking models, checked in order:
//   - a chunk containing "<thinking>" or "</thinking>" is dropped ("");
//   - a chunk containing "The answer is:" yields the first non-blank line
//     after that marker, trimmed, or "" if there is none;
//   - a chunk whose trimmed text starts with "I need to", "Let me think" or
//     "First, I'll" is dropped;
//   - anything else is returned unchanged.
//
// Each chunk is judged on its own; no state is carried between chunks.
func (s *AIService) filterThinkingContent(content, model string) string {
	if !s.isThinkingModel(model) {
		return content
	}
	for _, tag := range [...]string{"<thinking>", "</thinking>"} {
		if strings.Contains(content, tag) {
			return ""
		}
	}
	if _, after, found := strings.Cut(content, "The answer is:"); found {
		for _, ln := range strings.Split(after, "\n") {
			if t := strings.TrimSpace(ln); t != "" {
				return t
			}
		}
		return ""
	}
	lead := strings.TrimSpace(content)
	for _, preamble := range [...]string{"I need to", "Let me think", "First, I'll"} {
		if strings.HasPrefix(lead, preamble) {
			return ""
		}
	}
	return content
}
// isThinkingModel reports whether model should be treated as a
// reasoning/thinking model. It returns true when the lowercased model name
// contains any of the listed markers as a substring, so "o1" also matches
// "o1-preview" and "marco-o1".
//
// NOTE(review): because this is substring matching, "gpt-4" also matches
// "gpt-4o", "gpt-4.1", etc. Those models are then routed through
// filterThinkingContent; confirm this is intended.
func (s *AIService) isThinkingModel(model string) bool {
	// Markers are already lowercase, so only the model name needs folding.
	name := strings.ToLower(model)
	for _, marker := range [...]string{
		"o1-preview",
		"o1-mini",
		"o1",
		"qwen2.5-coder:32b",
		"deepseek-r1",
		"marco-o1",
		"gpt-4",
		"gpt-4-turbo",
		"claude-3",
	} {
		if strings.Contains(name, marker) {
			return true
		}
	}
	return false
}
// cleanJSONResponse returns response ready for JSON decoding.
//
// For providers with grammar support the response is returned untouched,
// because the grammar constrains it to raw JSON. For other providers the
// response is whitespace-trimmed, and if it starts with a Markdown code fence
// (```), that fence is removed along with:
//   - an optional "json" info string, matched case-insensitively so ```JSON
//     and ```Json are handled as well as ```json;
//   - a trailing closing fence.
//
// The result is trimmed again. Text before an opening fence is not handled.
func (s *AIService) cleanJSONResponse(ctx context.Context, response, provider string) string {
	_, span := observability.TraceAIFunction(ctx, "clean_json_response",
		attribute.String("ai.provider", provider),
		attribute.Int("response.length", len(response)),
	)
	defer span.End()

	if s.supportsGrammarField(provider) {
		return response
	}

	response = strings.TrimSpace(response)
	body, fenced := strings.CutPrefix(response, "```")
	if !fenced {
		return response
	}
	const infoString = "json"
	if len(body) >= len(infoString) && strings.EqualFold(body[:len(infoString)], infoString) {
		body = body[len(infoString):]
	}
	body = strings.TrimSuffix(body, "```")
	return strings.TrimSpace(body)
}
// parseQuestionsResponse decodes a batch AI response (a JSON array of question
// objects) into validated questions.
//
// The response is first cleaned with cleanJSONResponse. Each element is turned
// into a question with createQuestionFromData and checked with
// ValidateQuestionSchema. Elements that are nil, fail creation, or fail schema
// validation are skipped rather than failing the whole batch. Schema failures
// are also recorded in the package-level SchemaValidationFailures /
// SchemaValidationFailureDetails tallies (the last 10 details per type are kept).
//
// It returns an ErrAIResponseInvalid error when the input is empty, is not a
// JSON array of objects, or yields no valid question at all. A panic during
// parsing is recovered, logged, and returned as an ErrInternalError error.
func (s *AIService) parseQuestionsResponse(ctx context.Context, response, language, level string, qType models.QuestionType, provider string) (result0 []*models.Question, err error) {
	if s == nil {
		return nil, contextutils.WrapError(contextutils.ErrInternalError, "AIService instance is nil")
	}
	_, span := observability.TraceAIFunction(ctx, "parse_questions_response",
		observability.AttributeQuestionType(qType),
		observability.AttributeLanguage(language),
		observability.AttributeLevel(level),
		attribute.String("ai.provider", provider),
		attribute.Int("response.length", len(response)),
	)
	defer observability.FinishSpan(span, &err)
	// Deferred after FinishSpan, so it runs first (LIFO): a recovered panic is
	// turned into err before FinishSpan records the span outcome.
	defer func() {
		if r := recover(); r != nil {
			s.logger.Error(ctx, "PANIC in parseQuestionsResponse", nil, map[string]interface{}{
				"panic":    fmt.Sprintf("%v", r),
				"response": response,
				"stack":    string(debug.Stack()),
			})
			span.SetAttributes(attribute.String("parse.result", "panic"), attribute.String("panic", fmt.Sprintf("%v", r)))
			// Without this the function would return (nil, nil) after a panic
			// and callers would treat it as a successful parse.
			result0 = nil
			err = contextutils.WrapErrorf(contextutils.ErrInternalError, "panic while parsing AI response: %v", r)
		}
	}()

	// Validate input parameters
	if response == "" {
		span.SetAttributes(attribute.String("parse.result", "empty_response"))
		return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "AI provider returned empty response")
	}
	if language == "" {
		span.SetAttributes(attribute.String("parse.result", "empty_language"))
		return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "language cannot be empty")
	}
	if level == "" {
		span.SetAttributes(attribute.String("parse.result", "empty_level"))
		return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "level cannot be empty")
	}

	// Strip markdown code fences for providers without grammar support.
	cleanedResponse := s.cleanJSONResponse(ctx, response, provider)
	if cleanedResponse == "" {
		span.SetAttributes(attribute.String("parse.result", "empty_cleaned_response"))
		return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "AI provider returned empty response after cleaning")
	}

	var questions []map[string]interface{}
	if err := json.Unmarshal([]byte(cleanedResponse), &questions); err != nil {
		span.SetAttributes(attribute.String("parse.result", "json_unmarshal_failed"), attribute.String("error", err.Error()))
		return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "failed to parse AI response as JSON: %w", err)
	}
	if len(questions) == 0 {
		span.SetAttributes(attribute.String("parse.result", "no_questions_in_response"))
		return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "AI provider returned no questions in response")
	}

	var result []*models.Question
	var validationErrors []string
	var skippedCount int
	for i, qData := range questions {
		if qData == nil {
			skippedCount++
			span.SetAttributes(attribute.String("parse.result", "nil_question_data"), attribute.Int("question_index", i))
			continue
		}
		question, err := s.createQuestionFromData(ctx, qData, language, level, qType)
		if err != nil {
			// Report the first nil/empty field as a hint. Map order is random,
			// so with several empty fields the one named may vary between runs.
			var failedField, failedValue string
			for k, v := range qData {
				if v == nil || v == "" {
					failedField = k
					failedValue = fmt.Sprintf("%v", v)
					break
				}
			}
			// Creation failures are skipped questions too; count them so the
			// "skipped" figure in the final error message is accurate.
			skippedCount++
			validationErrors = append(validationErrors, fmt.Sprintf("question %d: %v (field: %s, value: %s)", i+1, err, failedField, failedValue))
			span.SetAttributes(attribute.String("parse.result", "question_creation_failed"), attribute.Int("question_index", i), attribute.String("error", err.Error()))
			continue
		}
		if question == nil {
			skippedCount++
			span.SetAttributes(attribute.String("parse.result", "nil_question_after_creation"), attribute.Int("question_index", i))
			continue
		}
		// JSON numbers decode as float64; coerce correct_answer to int so the
		// schema validator sees an integer.
		if m := question.Content; m != nil {
			if f, ok := m["correct_answer"].(float64); ok {
				m["correct_answer"] = int(f)
			}
		}
		valid, err := s.ValidateQuestionSchema(ctx, qType, question)
		if err != nil {
			validationErrors = append(validationErrors, fmt.Sprintf("question %d schema validation error: %v", i+1, err))
			span.SetAttributes(attribute.String("parse.result", "schema_validation_error"), attribute.Int("question_index", i), attribute.String("error", err.Error()))
		}
		if !valid {
			SchemaValidationMu.Lock()
			SchemaValidationFailures[qType]++
			if err != nil {
				SchemaValidationFailureDetails[qType] = append(SchemaValidationFailureDetails[qType], err.Error())
			} else {
				SchemaValidationFailureDetails[qType] = append(SchemaValidationFailureDetails[qType], "validation failed")
			}
			// Keep only the 10 most recent failure details per type.
			if len(SchemaValidationFailureDetails[qType]) > 10 {
				SchemaValidationFailureDetails[qType] = SchemaValidationFailureDetails[qType][len(SchemaValidationFailureDetails[qType])-10:]
			}
			SchemaValidationMu.Unlock()
			skippedCount++
			span.SetAttributes(attribute.String("parse.result", "schema_validation_failed"), attribute.Int("question_index", i))
			continue
		}
		result = append(result, question)
	}

	if len(validationErrors) > 0 {
		s.logger.Warn(ctx, "AI Service WARNING: validation errors in response", map[string]interface{}{
			"validation_errors_count": len(validationErrors),
			"validation_errors":       strings.Join(validationErrors, "; "),
		})
		span.SetAttributes(attribute.String("parse.result", "validation_errors"), attribute.String("errors", strings.Join(validationErrors, "; ")))
	}
	if len(result) == 0 {
		span.SetAttributes(attribute.String("parse.result", "no_valid_questions"), attribute.Int("total_questions", len(questions)), attribute.Int("skipped_count", skippedCount))
		return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "AI provider returned only invalid or empty questions (total: %d, skipped: %d)", len(questions), skippedCount)
	}
	span.SetAttributes(attribute.String("parse.result", "success"), attribute.Int("valid_questions", len(result)), attribute.Int("total_questions", len(questions)), attribute.Int("skipped_count", skippedCount))
	return result, nil
}
// createQuestionFromData builds a models.Question from one decoded JSON object.
//
// data is validated with validateQuestionContent, plus extra structural checks
// for reading comprehension (passage, question, exactly four string options,
// correct_answer). correct_answer may be an int, a whole-number float64 (how
// encoding/json decodes numbers), a numeric string such as "2", or the exact
// text of one of the options. It is resolved to an index and range-checked
// against options. When explanation is missing or empty, a per-type default is
// written back into data (data is mutated) so later schema validation sees it.
//
// All validation failures are returned as ErrAIResponseInvalid errors.
func (s *AIService) createQuestionFromData(ctx context.Context, data map[string]interface{}, language, level string, qType models.QuestionType) (result0 *models.Question, err error) {
	if s == nil {
		return nil, contextutils.WrapError(contextutils.ErrInternalError, "AIService instance is nil")
	}
	_, span := observability.TraceAIFunction(ctx, "create_question_from_data",
		observability.AttributeQuestionType(qType),
		observability.AttributeLanguage(language),
		observability.AttributeLevel(level),
		attribute.Int("data.fields", len(data)),
	)
	defer observability.FinishSpan(span, &err)

	if data == nil {
		span.SetAttributes(attribute.String("creation.result", "nil_data"))
		return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "question data is nil")
	}
	if language == "" {
		span.SetAttributes(attribute.String("creation.result", "empty_language"))
		return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "language cannot be empty")
	}
	if level == "" {
		span.SetAttributes(attribute.String("creation.result", "empty_level"))
		return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "level cannot be empty")
	}

	if ok, errMsg := s.validateQuestionContent(ctx, qType, data); !ok {
		// Mention nil/empty fields explicitly; they are the most common cause.
		missingFields := []string{}
		for k, v := range data {
			if v == nil || v == "" {
				missingFields = append(missingFields, k)
			}
		}
		if len(missingFields) > 0 {
			span.SetAttributes(attribute.String("creation.result", "validation_failed_with_missing_fields"), attribute.String("missing_fields", strings.Join(missingFields, ",")))
			return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "invalid question content structure: %s. Missing or empty fields: %v", errMsg, missingFields)
		}
		span.SetAttributes(attribute.String("creation.result", "validation_failed"), attribute.String("error", errMsg))
		return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "invalid question content structure: %s", errMsg)
	}

	// Defensive: reading comprehension must have passage, question, exactly
	// four string options, and correct_answer.
	if qType == models.ReadingComprehension {
		if _, ok := data["passage"].(string); !ok {
			span.SetAttributes(attribute.String("creation.result", "reading_missing_passage"))
			return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "reading comprehension question missing or invalid 'passage' field")
		}
		if _, ok := data["question"].(string); !ok {
			span.SetAttributes(attribute.String("creation.result", "reading_missing_question"))
			return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "reading comprehension question missing or invalid 'question' field")
		}
		options, ok := data["options"].([]interface{})
		if !ok || len(options) != 4 {
			span.SetAttributes(attribute.String("creation.result", "reading_invalid_options"))
			return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "reading comprehension question missing or invalid 'options' field (must be array of 4 strings)")
		}
		for i, opt := range options {
			if _, ok := opt.(string); !ok {
				span.SetAttributes(attribute.String("creation.result", "reading_invalid_option_type"), attribute.Int("option_index", i))
				return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "reading comprehension question 'options' must be array of strings, found invalid type at index %d", i)
			}
		}
		if _, ok := data["correct_answer"]; !ok {
			span.SetAttributes(attribute.String("creation.result", "reading_missing_correct_answer"))
			return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "reading comprehension question missing 'correct_answer' field")
		}
	}

	correctAnswerRaw, exists := data["correct_answer"]
	if !exists {
		span.SetAttributes(attribute.String("creation.result", "missing_correct_answer"))
		return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "missing correct_answer field")
	}

	// Resolve correct_answer to an option index.
	var correctAnswerIndex int
	switch v := correctAnswerRaw.(type) {
	case int:
		correctAnswerIndex = v
	case float64:
		// Reject fractional values instead of silently truncating (1.7 -> 1).
		// This matches validateQuestionContent's whole-number rule, which
		// is not applied to Vocabulary questions.
		if v != float64(int(v)) {
			span.SetAttributes(attribute.String("creation.result", "non_integer_correct_answer"))
			return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "correct_answer must be a whole number, got %v", v)
		}
		correctAnswerIndex = int(v)
	case string:
		if idx, convErr := strconv.Atoi(v); convErr == nil {
			// Numeric string such as "0".."3".
			correctAnswerIndex = idx
		} else {
			// Otherwise treat the value as the text of one of the options.
			options, ok := data["options"].([]interface{})
			if !ok {
				span.SetAttributes(attribute.String("creation.result", "no_options_for_text_answer"))
				return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "correct_answer is text '%s' but no options available to match against", v)
			}
			found := false
			for i, opt := range options {
				if optStr, ok := opt.(string); ok && optStr == v {
					correctAnswerIndex = i
					found = true
					break
				}
			}
			if !found {
				span.SetAttributes(attribute.String("creation.result", "correct_answer_not_found_in_options"))
				return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "correct_answer '%s' not found in options", v)
			}
		}
	default:
		span.SetAttributes(attribute.String("creation.result", "invalid_correct_answer_type"), attribute.String("type", fmt.Sprintf("%T", v)))
		return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "invalid correct_answer type: %T", v)
	}

	if options, ok := data["options"].([]interface{}); ok {
		if correctAnswerIndex < 0 || correctAnswerIndex >= len(options) {
			span.SetAttributes(attribute.String("creation.result", "invalid_correct_answer_index"), attribute.Int("index", correctAnswerIndex), attribute.Int("options_count", len(options)))
			return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "correct_answer index %d is out of range (0-%d)", correctAnswerIndex, len(options)-1)
		}
	}

	// Options are deliberately not shuffled here: the frontend shuffles, and
	// shuffling in both places would desynchronize answer indices.

	explanation, _ := data["explanation"].(string)
	if explanation == "" {
		switch qType {
		case models.Vocabulary:
			explanation = "This vocabulary question tests your knowledge of words in context."
		case models.ReadingComprehension:
			explanation = "This reading comprehension question tests your understanding of the passage."
		case models.FillInBlank:
			explanation = "This fill-in-the-blank question tests your grammar and vocabulary knowledge."
		case models.QuestionAnswer:
			explanation = "This question tests your conversational and practical language skills."
		default:
			explanation = "This question tests your language skills."
		}
		// Write it back so schema validation (which inspects Content) sees it.
		data["explanation"] = explanation
	}

	question := &models.Question{
		Type:            qType,
		Language:        language,
		Level:           level,
		DifficultyScore: s.getDifficultyScore(level),
		Content:         data,
		CorrectAnswer:   correctAnswerIndex,
		Explanation:     explanation,
		CreatedAt:       time.Now(),
	}
	span.SetAttributes(attribute.String("creation.result", "success"))
	return question, nil
}
// parseQuestionResponse decodes a single-question AI response (one JSON
// object) into a models.Question.
//
// The response is cleaned with cleanJSONResponse, decoded, and converted with
// createQuestionFromData. JSON and creation failures are logged with the raw
// model output and returned as ErrAIResponseInvalid errors.
//
// NOTE(review): unlike parseQuestionsResponse, a question that fails
// ValidateQuestionSchema is still returned (the failure is only logged and
// tallied in SchemaValidationFailures). Confirm callers rely on this.
func (s *AIService) parseQuestionResponse(ctx context.Context, response, language, level string, qType models.QuestionType, provider string) (result0 *models.Question, err error) {
	if s == nil {
		return nil, contextutils.WrapError(contextutils.ErrInternalError, "AIService instance is nil")
	}
	_, span := observability.TraceAIFunction(ctx, "parse_question_response",
		observability.AttributeQuestionType(qType),
		observability.AttributeLanguage(language),
		observability.AttributeLevel(level),
		attribute.String("ai.provider", provider),
		attribute.Int("response.length", len(response)),
	)
	defer observability.FinishSpan(span, &err)

	// Strip markdown code fences for providers without grammar support.
	cleanedResponse := s.cleanJSONResponse(ctx, response, provider)
	if cleanedResponse == "" {
		span.SetAttributes(attribute.String("parse.result", "empty_cleaned_response"))
		return nil, contextutils.WrapError(contextutils.ErrAIResponseInvalid, "AI provider returned empty response after cleaning")
	}

	var data map[string]interface{}
	if jsonErr := json.Unmarshal([]byte(cleanedResponse), &data); jsonErr != nil {
		s.logger.Error(ctx, "Failed to parse JSON response", jsonErr, map[string]interface{}{
			"raw_response": response,
		})
		return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "failed to parse AI response as JSON: %w", jsonErr)
	}

	question, createErr := s.createQuestionFromData(ctx, data, language, level, qType)
	if createErr != nil {
		s.logger.Error(ctx, "Failed to create question from data", createErr, map[string]interface{}{
			"raw_question_data":   data,
			"full_model_response": response,
		})
		return nil, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "failed to create question: %w", createErr)
	}

	valid, schemaErr := s.ValidateQuestionSchema(ctx, qType, question)
	if schemaErr != nil {
		s.logger.Error(ctx, "Schema validation error for question", schemaErr, nil)
	}
	if !valid {
		SchemaValidationMu.Lock()
		SchemaValidationFailures[qType]++
		if schemaErr != nil {
			SchemaValidationFailureDetails[qType] = append(SchemaValidationFailureDetails[qType], schemaErr.Error())
		} else {
			SchemaValidationFailureDetails[qType] = append(SchemaValidationFailureDetails[qType], "validation failed")
		}
		// Keep only the 10 most recent failure details per type.
		if len(SchemaValidationFailureDetails[qType]) > 10 {
			SchemaValidationFailureDetails[qType] = SchemaValidationFailureDetails[qType][len(SchemaValidationFailureDetails[qType])-10:]
		}
		SchemaValidationMu.Unlock()
	}
	return question, nil
}
// getDifficultyScore maps a level name to a score in [0.0, 1.0] by its
// position in the first configured language-level list that contains it
// (first entry 0.0, last entry 1.0).
//
// It returns 0.5 when no config is loaded, when the level is not found, or
// when the matching list has a single entry. In the single-entry case the
// interpolation i/(len-1) would be 0/0 and produce NaN.
// NOTE(review): if LanguageLevels is a map and the same level name sits at
// different positions for different languages, the match depends on iteration
// order. Confirm level lists are aligned.
func (s *AIService) getDifficultyScore(level string) float64 {
	const defaultScore = 0.5
	if s.cfg == nil || s.cfg.LanguageLevels == nil {
		return defaultScore
	}
	for _, langConfig := range s.cfg.LanguageLevels {
		levels := langConfig.Levels
		for i, lvl := range levels {
			if lvl != level {
				continue
			}
			if len(levels) < 2 {
				return defaultScore
			}
			return float64(i) / float64(len(levels)-1)
		}
	}
	return defaultScore
}
// validateQuestionContent checks that a decoded question object has the fields
// required for its type. It returns (true, "") when valid, or (false, reason)
// where reason names the first failing check.
//
// Required fields per type:
//   - Vocabulary: sentence, question (the target word), options. In addition,
//     options must have exactly 4 entries and question must occur in sentence.
//   - ReadingComprehension: passage, question, options, correct_answer.
//   - FillInBlank / QuestionAnswer: question, options, correct_answer.
//
// Strings must be non-empty. options must be at least 4 non-empty strings.
// correct_answer must be a non-negative int, a whole non-negative float64, a
// numeric string, or the exact text of one of the options. Fields are checked
// in the fixed order above, so the reported field is deterministic. (Ranging
// over a map made it random.)
func (s *AIService) validateQuestionContent(ctx context.Context, qType models.QuestionType, content map[string]interface{}) (bool, string) {
	_, span := observability.TraceAIFunction(ctx, "validate_question_content",
		observability.AttributeQuestionType(qType),
		attribute.Int("content.fields", len(content)),
	)
	defer span.End()

	if content == nil {
		span.SetAttributes(attribute.String("validation.result", "nil_content"))
		return false, "question content cannot be nil"
	}

	isString := func(v interface{}) bool {
		str, ok := v.(string)
		return ok && str != ""
	}
	isStringSlice := func(v interface{}) bool {
		slice, ok := v.([]interface{})
		if !ok || len(slice) < 4 {
			return false
		}
		for _, item := range slice {
			if str, ok := item.(string); !ok || str == "" {
				return false
			}
		}
		return true
	}
	isCorrectAnswer := func(v interface{}) bool {
		switch val := v.(type) {
		case int:
			return val >= 0
		case float64:
			return val >= 0 && val == float64(int(val)) // must be a whole number
		case string:
			// Accept numeric indices like "0".."3"...
			if _, err := strconv.Atoi(val); err == nil {
				return true
			}
			// ...or answer text that matches one of the options.
			if options, ok := content["options"].([]interface{}); ok {
				for _, opt := range options {
					if optStr, ok := opt.(string); ok && optStr == val {
						return true
					}
				}
			}
			return false
		default:
			return false
		}
	}

	type fieldRule struct {
		name  string
		check func(interface{}) bool
	}

	var label string
	var rules []fieldRule
	switch qType {
	case models.Vocabulary:
		label = "Vocabulary"
		rules = []fieldRule{{"sentence", isString}, {"question", isString}, {"options", isStringSlice}}
	case models.ReadingComprehension:
		label = "ReadingComprehension"
		rules = []fieldRule{{"passage", isString}, {"question", isString}, {"options", isStringSlice}, {"correct_answer", isCorrectAnswer}}
	case models.FillInBlank:
		// Fill-in-blank questions use the multiple-choice format.
		label = "FillInBlank"
		rules = []fieldRule{{"question", isString}, {"options", isStringSlice}, {"correct_answer", isCorrectAnswer}}
	case models.QuestionAnswer:
		// Question-answer questions use the multiple-choice format.
		label = "QuestionAnswer"
		rules = []fieldRule{{"question", isString}, {"options", isStringSlice}, {"correct_answer", isCorrectAnswer}}
	default:
		span.SetAttributes(attribute.String("validation.result", "unknown_type"))
		return false, fmt.Sprintf("unknown question type: %v", qType)
	}

	for _, rule := range rules {
		value := content[rule.name]
		if !rule.check(value) {
			span.SetAttributes(attribute.String("validation.result", "field_validation_failed"), attribute.String("field", rule.name))
			return false, fmt.Sprintf("[%s] Validation failed for field '%s': %v", label, rule.name, value)
		}
	}

	if qType == models.Vocabulary {
		sentence, _ := content["sentence"].(string)
		targetWord, _ := content["question"].(string)
		options, _ := content["options"].([]interface{})
		if sentence == "" || targetWord == "" || len(options) != 4 {
			span.SetAttributes(attribute.String("validation.result", "vocabulary_structure_failed"))
			return false, "[Vocabulary] Validation failed: missing or invalid sentence/question/options"
		}
		// The target word must appear verbatim (case-sensitive) in the sentence.
		if !strings.Contains(sentence, targetWord) {
			span.SetAttributes(attribute.String("validation.result", "vocabulary_word_not_found"))
			return false, fmt.Sprintf("[Vocabulary] Validation failed: question '%s' not found in sentence '%s'", targetWord, sentence)
		}
	}

	span.SetAttributes(attribute.String("validation.result", "valid"))
	return true, ""
}
// GetConcurrencyStats returns a point-in-time snapshot of the AI concurrency
// counters. UserActiveCount contains only users with at least one in-flight
// request. QueuedRequests is always 0 because requests fail fast instead of
// queueing.
//
// The two mutexes are taken one after the other, never nested. The previous
// statsMu -> concurrencyMu order was the reverse of releaseGlobalSlot's
// concurrencyMu -> statsMu and could deadlock. As a result, the global and
// per-user figures may come from slightly different instants.
func (s *AIService) GetConcurrencyStats() ConcurrencyStats {
	s.concurrencyMu.RLock()
	userActiveCount := make(map[string]int, len(s.userRequestCount))
	for username, count := range s.userRequestCount {
		if count > 0 {
			userActiveCount[username] = count
		}
	}
	maxPerUser := s.maxPerUser
	s.concurrencyMu.RUnlock()

	s.statsMu.RLock()
	activeRequests := s.activeRequests
	totalRequests := s.totalRequests
	maxConcurrent := s.maxConcurrent
	s.statsMu.RUnlock()

	queuedRequests := 0 // requests are never queued; see acquireGlobalSlot

	return ConcurrencyStats{
		ActiveRequests:  activeRequests,
		MaxConcurrent:   maxConcurrent,
		QueuedRequests:  queuedRequests,
		TotalRequests:   totalRequests,
		UserActiveCount: userActiveCount,
		MaxPerUser:      maxPerUser,
	}
}
// acquireGlobalSlot tries to take one slot from the global semaphore without
// waiting. It returns nil on success, an ErrTimeout-wrapped error if ctx is
// already done, or an ErrServiceUnavailable-wrapped error when every slot is
// taken. When both a free slot and a done ctx are ready, Go's select picks one
// of them at random.
func (s *AIService) acquireGlobalSlot(ctx context.Context) error {
	select {
	case <-ctx.Done():
		return contextutils.WrapErrorf(contextutils.ErrTimeout, "request cancelled while waiting for global AI slot: %w", ctx.Err())
	case s.globalSemaphore <- struct{}{}:
		return nil
	default:
		// Fail fast rather than queue behind other requests.
		return contextutils.WrapErrorf(contextutils.ErrServiceUnavailable, "AI service at capacity (%d concurrent requests), please try again", s.maxConcurrent)
	}
}
// releaseGlobalSlot gives back one global semaphore slot and decrements
// activeRequests (never below zero). If the semaphore is already empty, the
// call is unmatched and only a warning is logged.
//
// concurrencyMu is intentionally not held here. Channel operations are already
// goroutine-safe, and taking concurrencyMu before statsMu formed a lock-order
// cycle with GetConcurrencyStats (statsMu before concurrencyMu), which could
// deadlock.
func (s *AIService) releaseGlobalSlot(ctx context.Context) {
	select {
	case <-s.globalSemaphore:
		s.statsMu.Lock()
		if s.activeRequests > 0 {
			s.activeRequests--
		}
		s.statsMu.Unlock()
	default:
		s.logger.Warn(ctx, "WARNING: Attempted to release global AI slot but none were acquired", nil)
	}
}
// acquireUserSlot reserves one of username's per-user request slots. It fails
// with an ErrServiceUnavailable-wrapped error once the user already has
// maxPerUser requests in flight. The context argument is currently unused.
func (s *AIService) acquireUserSlot(_ context.Context, username string) error {
	s.concurrencyMu.Lock()
	defer s.concurrencyMu.Unlock()

	inFlight := s.userRequestCount[username]
	if inFlight < s.maxPerUser {
		s.userRequestCount[username] = inFlight + 1
		return nil
	}
	return contextutils.WrapErrorf(contextutils.ErrServiceUnavailable, "user concurrency limit exceeded for %s: %d/%d", username, inFlight, s.maxPerUser)
}
// releaseUserSlot gives back one of username's per-user slots. When the count
// reaches zero the map entry is deleted, so userRequestCount does not grow
// without bound as distinct users come and go. Readers treat a missing key as
// zero. Releasing with no slot held logs a warning.
func (s *AIService) releaseUserSlot(ctx context.Context, username string) {
	s.concurrencyMu.Lock()
	defer s.concurrencyMu.Unlock()

	switch current := s.userRequestCount[username]; {
	case current > 1:
		s.userRequestCount[username] = current - 1
	case current == 1:
		delete(s.userRequestCount, username)
	default:
		s.logger.Warn(ctx, "WARNING: Attempted to release user AI slot but none were acquired", map[string]interface{}{
			"username": username,
		})
	}
}
// incrementTotalRequests bumps the lifetime request counter under statsMu.
func (s *AIService) incrementTotalRequests() {
	s.statsMu.Lock()
	s.totalRequests++
	s.statsMu.Unlock()
}
// withConcurrencyControl runs operation only if the service is not shutting
// down and both a global slot and a per-user slot for username are available.
// Every attempt that gets past the shutdown check is counted in
// totalRequests. Acquired slots are released when operation returns (or
// panics). Limit errors are returned without running operation.
func (s *AIService) withConcurrencyControl(ctx context.Context, username string, operation func() error) error {
	if s.isShutdown() {
		return contextutils.WrapError(contextutils.ErrServiceUnavailable, "AI service is shutting down")
	}
	s.incrementTotalRequests()

	if globalErr := s.acquireGlobalSlot(ctx); globalErr != nil {
		return globalErr
	}
	s.statsMu.Lock()
	s.activeRequests++
	s.statsMu.Unlock()
	defer s.releaseGlobalSlot(ctx)

	if userErr := s.acquireUserSlot(ctx, username); userErr != nil {
		return userErr
	}
	defer s.releaseUserSlot(ctx, username)

	return operation()
}
// supportsGrammarField reports whether the configured provider whose Code
// equals provider has SupportsGrammar set. It returns false for unknown
// providers and, unlike before, also when no config is loaded (s.cfg == nil)
// instead of panicking. This matches getDifficultyScore's nil handling.
func (s *AIService) supportsGrammarField(provider string) bool {
	if s.cfg == nil {
		return false
	}
	for _, providerConfig := range s.cfg.Providers {
		if providerConfig.Code == provider {
			return providerConfig.SupportsGrammar
		}
	}
	return false
}
// getQuestionBatchSize returns how many questions may be requested from
// provider in one call: its configured QuestionBatchSize when positive,
// otherwise 1. It also returns 1 for unknown providers and when no config is
// loaded (s.cfg == nil), instead of panicking on the nil config.
func (s *AIService) getQuestionBatchSize(provider string) int {
	const defaultBatchSize = 1
	if s.cfg == nil {
		return defaultBatchSize
	}
	for _, p := range s.cfg.Providers {
		if p.Code != provider {
			continue
		}
		if p.QuestionBatchSize > 0 {
			return p.QuestionBatchSize
		}
		return defaultBatchSize
	}
	return defaultBatchSize
}
// GetQuestionBatchSize is the exported form of getQuestionBatchSize. It reports
// the per-request question limit for provider, which is 1 when the provider is
// unknown or has no positive QuestionBatchSize configured.
func (s *AIService) GetQuestionBatchSize(provider string) int {
	size := s.getQuestionBatchSize(provider)
	return size
}
// VarietyService exposes the *VarietyService stored on this AIService. The
// pointer is returned as-is and may be nil if none was configured.
func (s *AIService) VarietyService() *VarietyService {
	vs := s.varietyService
	return vs
}
// TemplateManager exposes the *AITemplateManager used to render prompt
// templates and load example JSON. The stored pointer is returned as-is.
func (s *AIService) TemplateManager() *AITemplateManager {
	tm := s.templateManager
	return tm
}
// SupportsGrammarField is the exported form of supportsGrammarField. It
// reports whether provider's config enables the grammar field, and returns
// false for providers that are not configured.
func (s *AIService) SupportsGrammarField(provider string) bool {
	supported := s.supportsGrammarField(provider)
	return supported
}
// CallWithPrompt sends a raw prompt, plus an optional grammar string (pass ""
// for none), to the provider described by userConfig and returns the raw
// response text. It forwards directly to callOpenAI (defined elsewhere),
// presumably an OpenAI-compatible client; confirm for non-OpenAI providers.
func (s *AIService) CallWithPrompt(ctx context.Context, userConfig *UserAIConfig, prompt, grammar string) (string, error) {
	text, err := s.callOpenAI(ctx, userConfig, prompt, grammar)
	return text, err
}
// Package services provides embedded templates for AI service prompts
package services
import (
"embed"
"fmt"
"strings"
"text/template"
contextutils "quizapp/internal/utils"
)
// aiTemplatesFS holds every templates/*.tmpl file compiled into the binary.
// NewAITemplateManager parses it with text/template.
//
//go:embed templates/*.tmpl
var aiTemplatesFS embed.FS

// exampleFilesFS holds the templates/examples/*.json files compiled into the
// binary. LoadExample reads them on demand.
//
//go:embed templates/examples/*.json
var exampleFilesFS embed.FS
// Names of the prompt templates that can be passed to
// AITemplateManager.RenderTemplate. text/template's ParseFS names each template
// after the base name of its file, so these match files under templates/.
const (
	// AIFixPromptTemplate renders the prompt that asks the model to fix a question.
	AIFixPromptTemplate = "ai_fix_prompt.tmpl"
	// BatchQuestionPromptTemplate renders the batch question-generation prompt.
	BatchQuestionPromptTemplate = "batch_question_prompt.tmpl"
	// ChatPromptTemplate renders the chat prompt.
	ChatPromptTemplate = "chat_prompt.tmpl"
	// JSONStructureGuidanceTemplate renders JSON output-structure guidance.
	JSONStructureGuidanceTemplate = "json_structure_guidance.tmpl"
)
// AITemplateData is the data value passed to text/template when rendering AI
// prompt templates (see AITemplateManager.RenderTemplate). Templates look up
// these fields by name at execution time, so renaming or removing a field
// breaks the affected template at render time rather than at compile time.
// Most fields are used only by some templates. Unused fields can be left zero.
type AITemplateData struct {
	// Common fields
	Language              string
	Level                 string
	QuestionType          string
	Topic                 string
	RecentQuestionHistory []string // presumably recent questions to avoid repeating — confirm in templates
	ReportReasons         []string
	Count                 int // For batch generation
	// Variety fields for question generation
	TopicCategory      string
	GrammarFocus       string
	VocabularyDomain   string
	Scenario           string
	StyleModifier      string
	DifficultyModifier string
	TimeContext        string
	// Schema and formatting
	SchemaForPrompt     string // for direct inclusion in prompt for non-grammar providers
	ExampleContent      string // for including example in prompt
	CurrentQuestionJSON string // the actual question JSON to pass into ai-fix prompt
	AdditionalContext   string // optional freeform context provided by admin when requesting AI fix
	// Explanation specific
	Question      string
	UserAnswer    string
	CorrectAnswer string // The text of the correct answer for explanations
	// Chat specific
	Passage             string
	Options             []string
	IsCorrect           *bool // pointer so templates can tell "unset" (nil) from false — confirm intended use
	ConversationHistory []ChatMessage
	UserMessage         string
	// Priority-aware generation fields (NEW)
	UserWeakAreas        []string
	HighPriorityTopics   []string
	GapAnalysis          map[string]int // key/value meaning not visible here — presumably area -> gap count; confirm
	FocusOnWeakAreas     bool
	FreshQuestionRatio   float64
	PriorityDistribution map[string]int // key/value meaning not visible here; confirm against callers
}
// ChatMessage is one turn of conversation history handed to the chat prompt
// template via AITemplateData.ConversationHistory. Role identifies the speaker
// (the accepted values are defined by callers and templates), and Content is
// the message text.
type ChatMessage struct {
	Role, Content string
}
// AITemplateManager renders the embedded AI prompt templates and loads the
// embedded per-question-type example JSON. Create it with NewAITemplateManager.
// The zero value has a nil template set, so RenderTemplate on it would panic.
type AITemplateManager struct {
	templates *template.Template // every templates/*.tmpl, keyed by file base name
}
// NewAITemplateManager parses every embedded templates/*.tmpl file into a
// single template set. Any parse error from text/template is returned
// unchanged, with a nil manager.
func NewAITemplateManager() (result0 *AITemplateManager, err error) {
	parsed, parseErr := template.New("").ParseFS(aiTemplatesFS, "templates/*.tmpl")
	if parseErr != nil {
		return nil, parseErr
	}
	result0 = &AITemplateManager{templates: parsed}
	return result0, nil
}
// RenderTemplate executes the template named templateName (for example
// BatchQuestionPromptTemplate) with data and returns the rendered text. On an
// execution error, any partially rendered output is discarded and "" is
// returned with the error.
func (tm *AITemplateManager) RenderTemplate(templateName string, data AITemplateData) (result0 string, err error) {
	var out strings.Builder
	if execErr := tm.templates.ExecuteTemplate(&out, templateName, data); execErr != nil {
		return "", execErr
	}
	return out.String(), nil
}
// LoadExample returns the embedded example JSON at
// templates/examples/<questionType>_example.json. A missing file is reported
// as an ErrInternalError-wrapped error that preserves the underlying cause.
func (tm *AITemplateManager) LoadExample(questionType string) (result0 string, err error) {
	raw, readErr := exampleFilesFS.ReadFile(fmt.Sprintf("templates/examples/%s_example.json", questionType))
	if readErr != nil {
		return "", contextutils.WrapErrorf(contextutils.ErrInternalError, "failed to load example for %s: %w", questionType, readErr)
	}
	result0 = string(raw)
	return result0, nil
}
package services
import (
"context"
"database/sql"
"errors"
"time"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"quizapp/internal/observability"
)
// CleanupService handles database maintenance and cleanup tasks. It deletes
// questions with unsupported types and user responses whose question no
// longer exists. Create it with NewCleanupServiceWithLogger.
type CleanupService struct {
	db     *sql.DB               // may be nil; the cleanup/stats methods check and return an error
	logger *observability.Logger // used for progress and error messages
}
// NewCleanupServiceWithLogger returns a CleanupService that runs its queries
// on db and reports progress through logger. Neither argument is validated
// here. A nil db is detected later by each cleanup method.
func NewCleanupServiceWithLogger(db *sql.DB, logger *observability.Logger) *CleanupService {
	svc := new(CleanupService)
	svc.db = db
	svc.logger = logger
	return svc
}
// CleanupLegacyQuestionTypes deletes questions whose type is not one of the
// supported values ('vocabulary', 'fill_blank', 'qa', 'reading_comprehension').
//
// Matching rows are counted first so that the "nothing to do" case is logged
// and returns without issuing a DELETE. Database errors are added to the span
// as an "error" attribute, recorded with a stack trace by the deferred
// handler, and returned unchanged.
func (c *CleanupService) CleanupLegacyQuestionTypes(ctx context.Context) (err error) {
	ctx, span := observability.TraceCleanupFunction(ctx, "cleanup_legacy_question_types")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// fail tags the span with the error text and hands the error back.
	fail := func(e error) error {
		span.SetAttributes(attribute.String("error", e.Error()))
		return e
	}

	if c.db == nil {
		return errors.New("database connection not available")
	}

	var legacyCount int
	countErr := c.db.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM questions
WHERE type NOT IN ('vocabulary', 'fill_blank', 'qa', 'reading_comprehension')
`).Scan(&legacyCount)
	if countErr != nil {
		return fail(countErr)
	}
	span.SetAttributes(attribute.Int("cleanup.legacy_questions_count", legacyCount))

	if legacyCount == 0 {
		c.logger.Info(ctx, "No legacy question types found to cleanup", map[string]interface{}{})
		span.SetAttributes(attribute.String("cleanup.result", "no_legacy_questions"))
		return nil
	}
	c.logger.Info(ctx, "Found questions with legacy types to cleanup", map[string]interface{}{"count": legacyCount})

	res, execErr := c.db.ExecContext(ctx, `
DELETE FROM questions
WHERE type NOT IN ('vocabulary', 'fill_blank', 'qa', 'reading_comprehension')
`)
	if execErr != nil {
		return fail(execErr)
	}
	deleted, rowsErr := res.RowsAffected()
	if rowsErr != nil {
		return fail(rowsErr)
	}

	span.SetAttributes(
		attribute.Int64("cleanup.rows_affected", deleted),
		attribute.String("cleanup.result", "success"),
	)
	c.logger.Info(ctx, "Successfully cleaned up questions with legacy types", map[string]interface{}{"rows_affected": deleted})
	return nil
}
// CleanupOrphanedResponses deletes user_responses rows whose question_id has no
// matching row in questions.
//
// The COUNT and the DELETE use the same definition of "orphaned": no question
// joins on question_id. That definition includes rows whose question_id is NULL.
// The DELETE previously used `question_id NOT IN (SELECT id FROM questions)`.
// Under SQL three-valued logic that never matches a NULL question_id, so such
// rows were counted as orphans on every run but never removed. NOT EXISTS
// fixes that and is planned as an anti-join.
//
// Database errors are added to the span, recorded by the deferred handler, and
// returned unchanged.
func (c *CleanupService) CleanupOrphanedResponses(ctx context.Context) (err error) {
	ctx, span := observability.TraceCleanupFunction(ctx, "cleanup_orphaned_responses")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	if c.db == nil {
		return errors.New("database connection not available")
	}

	var count int
	err = c.db.QueryRowContext(ctx, `
SELECT COUNT(*)
FROM user_responses ur
LEFT JOIN questions q ON ur.question_id = q.id
WHERE q.id IS NULL
`).Scan(&count)
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return err
	}
	span.SetAttributes(attribute.Int("cleanup.orphaned_responses_count", count))

	if count == 0 {
		c.logger.Info(ctx, "No orphaned responses found to cleanup", map[string]interface{}{})
		span.SetAttributes(attribute.String("cleanup.result", "no_orphaned_responses"))
		return nil
	}
	c.logger.Info(ctx, "Found orphaned responses to cleanup", map[string]interface{}{"count": count})

	result, err := c.db.ExecContext(ctx, `
DELETE FROM user_responses
WHERE NOT EXISTS (
SELECT 1 FROM questions q WHERE q.id = user_responses.question_id
)
`)
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return err
	}
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return err
	}

	span.SetAttributes(
		attribute.Int64("cleanup.rows_affected", rowsAffected),
		attribute.String("cleanup.result", "success"),
	)
	c.logger.Info(ctx, "Successfully cleaned up orphaned responses", map[string]interface{}{"rows_affected": rowsAffected})
	return nil
}
// RunFullCleanup performs all cleanup operations in sequence:
// CleanupLegacyQuestionTypes, then CleanupOrphanedResponses.
//
// It stops at the first step that fails, logs that failure, and returns the
// step's error unchanged. Start and end times (RFC 3339) are recorded both as
// span attributes and in the log.
func (c *CleanupService) RunFullCleanup(ctx context.Context) (err error) {
	ctx, span := observability.TraceCleanupFunction(ctx, "run_full_cleanup")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Capture each timestamp once so the span attribute and the log line agree.
	startTime := time.Now().Format(time.RFC3339)
	span.SetAttributes(attribute.String("cleanup.start_time", startTime))
	c.logger.Info(ctx, "Starting database cleanup", map[string]interface{}{"start_time": startTime})

	if err = c.CleanupLegacyQuestionTypes(ctx); err != nil {
		c.logger.Error(ctx, "Failed to cleanup legacy question types", err, map[string]interface{}{})
		span.SetAttributes(attribute.String("error", err.Error()))
		return err
	}
	// Assign to the named result (as above) instead of shadowing it with :=.
	if err = c.CleanupOrphanedResponses(ctx); err != nil {
		c.logger.Error(ctx, "Failed to cleanup orphaned responses", err, map[string]interface{}{})
		span.SetAttributes(attribute.String("error", err.Error()))
		return err
	}

	endTime := time.Now().Format(time.RFC3339)
	span.SetAttributes(
		attribute.String("cleanup.end_time", endTime),
		attribute.String("cleanup.result", "success"),
	)
	c.logger.Info(ctx, "Database cleanup completed successfully", map[string]interface{}{"end_time": endTime})
	return nil
}
// GetCleanupStats reports how many rows the cleanup steps would target,
// without modifying anything.
//
// The returned map has two keys:
//   - "legacy_questions": questions whose type is not one of
//     vocabulary, fill_blank, qa or reading_comprehension;
//   - "orphaned_responses": user responses with no matching questions row.
//
// Both counts are also recorded as span attributes. The first failing query
// aborts the call and its error is returned with a nil map.
func (c *CleanupService) GetCleanupStats(ctx context.Context) (result0 map[string]int, err error) {
	ctx, span := observability.TraceCleanupFunction(ctx, "get_cleanup_stats")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	if c.db == nil {
		return nil, errors.New("database connection not available")
	}

	// Each entry is a stats key and the COUNT(*) query that fills it; they
	// run in order and the loop stops at the first failure.
	counters := []struct {
		key   string
		query string
	}{
		{
			key: "legacy_questions",
			query: `
SELECT COUNT(*)
FROM questions
WHERE type NOT IN ('vocabulary', 'fill_blank', 'qa', 'reading_comprehension')
`,
		},
		{
			key: "orphaned_responses",
			query: `
SELECT COUNT(*)
FROM user_responses ur
LEFT JOIN questions q ON ur.question_id = q.id
WHERE q.id IS NULL
`,
		},
	}

	stats := make(map[string]int, len(counters))
	for _, counter := range counters {
		var n int
		if err = c.db.QueryRowContext(ctx, counter.query).Scan(&n); err != nil {
			span.SetAttributes(attribute.String("error", err.Error()))
			return nil, err
		}
		stats[counter.key] = n
	}

	span.SetAttributes(
		attribute.Int("cleanup.stats.legacy_questions", stats["legacy_questions"]),
		attribute.Int("cleanup.stats.orphaned_responses", stats["orphaned_responses"]),
	)
	return stats, nil
}
package services
import (
"context"
"database/sql"
"fmt"
"time"
"quizapp/internal/api"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)
// DailyQuestionServiceInterface defines the interface for daily question operations
type DailyQuestionServiceInterface interface {
AssignDailyQuestions(ctx context.Context, userID int, date time.Time) error
RegenerateDailyQuestions(ctx context.Context, userID int, date time.Time) error
GetDailyQuestions(ctx context.Context, userID int, date time.Time) ([]*models.DailyQuestionAssignmentWithQuestion, error)
MarkQuestionCompleted(ctx context.Context, userID, questionID int, date time.Time) error
ResetQuestionCompleted(ctx context.Context, userID, questionID int, date time.Time) error
SubmitDailyQuestionAnswer(ctx context.Context, userID, questionID int, date time.Time, userAnswerIndex int) (*api.AnswerResponse, error)
GetAvailableDates(ctx context.Context, userID int) ([]time.Time, error)
GetDailyProgress(ctx context.Context, userID int, date time.Time) (*models.DailyProgress, error)
GetDailyQuestionsCount(ctx context.Context, userID int, date time.Time) (int, error)
GetCompletedDailyQuestionsCount(ctx context.Context, userID int, date time.Time) (int, error)
GetQuestionHistory(ctx context.Context, userID, questionID, days int) ([]*models.DailyQuestionHistory, error)
}
// DailyQuestionService implements daily question assignment and management
// on top of the daily_question_assignments table.
type DailyQuestionService struct {
// db executes every SQL statement this service issues directly.
db *sql.DB
// logger receives the service's structured Info/Warn/Error events.
logger *observability.Logger
// questionService supplies adaptive candidates (GetAdaptiveQuestionsForDaily),
// paginated totals for diagnostics, and question lookups (GetQuestionByID).
questionService QuestionServiceInterface
// learningService supplies learning preferences (daily goal) and records
// user responses; SubmitDailyQuestionAnswer checks it for nil before use.
learningService LearningServiceInterface
}
// NewDailyQuestionService wires a DailyQuestionService from its dependencies.
// The arguments are stored as given; no validation or I/O happens here.
func NewDailyQuestionService(db *sql.DB, logger *observability.Logger, questionService QuestionServiceInterface, learningService LearningServiceInterface) *DailyQuestionService {
	svc := new(DailyQuestionService)
	svc.db = db
	svc.logger = logger
	svc.questionService = questionService
	svc.learningService = learningService
	return svc
}
// AssignDailyQuestions tops up userID's daily question assignments for date
// until they reach the user's daily goal: the DailyGoal learning preference
// when it is positive, otherwise 10. Existing assignments are kept; only the
// missing slots are filled.
//
// Candidates come from questionService.GetAdaptiveQuestionsForDaily. Nil
// entries, duplicate candidates, and questions already assigned for the date
// are skipped. Each insert is guarded by NOT EXISTS, so re-inserting a row
// that this transaction can already see is a no-op. Returns
// *NoQuestionsAvailableError (with diagnostic counts) when the adaptive query
// yields no candidates at all.
func (s *DailyQuestionService) AssignDailyQuestions(ctx context.Context, userID int, date time.Time) (err error) {
	dateStr := date.Format("2006-01-02")
	ctx, span := otel.Tracer("daily-question-service").Start(ctx, "AssignDailyQuestions",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.String("date", dateStr),
		),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Get user to determine language and level preferences
	user, err := s.getUserByID(ctx, userID)
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to get user")
	}
	if user == nil {
		return contextutils.ErrorWithContextf("user not found: %d", userID)
	}
	language := user.PreferredLanguage.String
	level := user.CurrentLevel.String
	if language == "" || level == "" {
		return contextutils.ErrorWithContextf("user missing language or level preferences")
	}

	// Resolve the daily goal. SubmitDailyQuestionAnswer already treats a nil
	// learningService as optional, so fall back to the default here instead
	// of panicking.
	goal := 10
	if s.learningService != nil {
		prefs, perr := s.learningService.GetUserLearningPreferences(ctx, userID)
		if perr != nil {
			span.RecordError(perr)
			return contextutils.WrapError(perr, "failed to get user learning preferences")
		}
		if prefs != nil && prefs.DailyGoal > 0 {
			goal = prefs.DailyGoal
		}
	}

	// Only fill the slots that are still missing.
	existingCount, err := s.GetDailyQuestionsCount(ctx, userID, date)
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to check existing assignments")
	}
	if existingCount >= goal {
		s.logger.Info(ctx, "Daily questions already assigned for date", map[string]interface{}{
			"user_id": userID,
			"date":    dateStr,
			"count":   existingCount,
			"goal":    goal,
		})
		return nil // Already assigned
	}
	toAssign := goal - existingCount

	// Request extra candidates so that dropping already-assigned, duplicate
	// or nil entries still leaves enough to meet the goal.
	const candidateBuffer = 10
	questionsWithStats, err := s.questionService.GetAdaptiveQuestionsForDaily(ctx, userID, language, level, goal+candidateBuffer)
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to get adaptive questions for assignment")
	}
	if len(questionsWithStats) == 0 {
		// Gather diagnostics to explain why no questions were available
		var candidateIDs []int
		candidateCount := 0
		totalMatching := 0
		if s.questionService != nil {
			if candidates, qerr := s.questionService.GetAdaptiveQuestionsForDaily(ctx, userID, language, level, 50); qerr == nil && candidates != nil {
				candidateCount = len(candidates)
				for i, q := range candidates {
					if i >= 10 {
						break
					}
					if q != nil {
						candidateIDs = append(candidateIDs, q.ID)
					}
				}
			}
			if _, total, terr := s.questionService.GetAllQuestionsPaginated(ctx, 1, 1, "", "", "", language, level, nil); terr == nil {
				totalMatching = total
			}
		}
		return &NoQuestionsAvailableError{
			Language:       language,
			Level:          level,
			CandidateIDs:   candidateIDs,
			CandidateCount: candidateCount,
			TotalMatching:  totalMatching,
		}
	}

	// Best-effort lookup of questions already assigned for this date. On
	// failure we proceed unfiltered: the NOT EXISTS guard below still skips
	// rows that already exist; we may simply fill fewer slots than requested.
	assignedIDs := make(map[int]bool, existingCount)
	rows, qerr := s.db.QueryContext(ctx, `SELECT question_id FROM daily_question_assignments WHERE user_id = $1 AND assignment_date = $2`, userID, date)
	if qerr != nil {
		s.logger.Warn(ctx, "Failed to load existing daily assignments", map[string]interface{}{
			"user_id": userID,
			"date":    dateStr,
			"error":   qerr.Error(),
		})
	} else {
		for rows.Next() {
			var qid int
			if scanErr := rows.Scan(&qid); scanErr != nil {
				s.logger.Warn(ctx, "Failed to scan assigned question id", map[string]interface{}{"error": scanErr.Error()})
				continue
			}
			assignedIDs[qid] = true
		}
		if iterErr := rows.Err(); iterErr != nil {
			s.logger.Warn(ctx, "Error iterating existing daily assignments", map[string]interface{}{"error": iterErr.Error()})
		}
		// Close explicitly here (rather than deferring to function exit) so
		// the connection is released before the transaction below starts.
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}

	// Pick at most toAssign new, distinct, non-nil questions.
	questions := make([]models.Question, 0, toAssign)
	for _, qws := range questionsWithStats {
		if len(questions) >= toAssign {
			break
		}
		if qws == nil || qws.Question == nil || assignedIDs[qws.ID] {
			continue
		}
		// Mark as taken so a duplicate candidate cannot consume a second slot.
		assignedIDs[qws.ID] = true
		questions = append(questions, *qws.Question)
	}
	if len(questions) == 0 {
		s.logger.Info(ctx, "No new daily questions to assign", map[string]interface{}{
			"user_id":  userID,
			"date":     dateStr,
			"existing": existingCount,
			"goal":     goal,
		})
		return nil
	}

	// Begin transaction
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to begin transaction")
	}
	defer func() {
		if err != nil {
			if rollbackErr := tx.Rollback(); rollbackErr != nil {
				s.logger.Error(ctx, "Failed to rollback transaction", rollbackErr, map[string]interface{}{
					"user_id": userID,
					"date":    dateStr,
				})
			}
		}
	}()

	// Insert assignments (idempotent via conditional INSERT to avoid duplicate rows)
	insertQuery := `
INSERT INTO daily_question_assignments (user_id, question_id, assignment_date, created_at)
SELECT $1, $2, $3, $4
WHERE NOT EXISTS (
SELECT 1 FROM daily_question_assignments WHERE user_id = $1 AND question_id = $2 AND assignment_date = $3
)
`
	inserted := 0
	for _, question := range questions {
		var res sql.Result
		res, err = tx.ExecContext(ctx, insertQuery, userID, question.ID, date, time.Now())
		if err != nil {
			span.RecordError(err)
			return contextutils.WrapError(err, "failed to insert assignment")
		}
		// An existing row turns the guarded insert into a no-op, so count rows
		// actually written; if the driver cannot report it, assume one.
		if n, raErr := res.RowsAffected(); raErr == nil {
			inserted += int(n)
		} else {
			inserted++
		}
	}

	// Commit transaction
	err = tx.Commit()
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to commit transaction")
	}
	s.logger.Info(ctx, "Daily questions assigned successfully", map[string]interface{}{
		"user_id":   userID,
		"date":      dateStr,
		"count":     inserted,
		"requested": len(questions),
	})
	return nil
}
// RegenerateDailyQuestions replaces all of userID's daily question
// assignments for date with a fresh adaptive selection of at most goal
// questions, where goal is the DailyGoal learning preference when positive,
// otherwise 10.
//
// Nil and duplicate candidates are skipped. Deleting the old assignments and
// inserting the new ones happen in one transaction, so a failure leaves the
// previous assignments in place. Returns *NoQuestionsAvailableError (with
// diagnostic counts) when there is no usable candidate; nothing is deleted in
// that case.
func (s *DailyQuestionService) RegenerateDailyQuestions(ctx context.Context, userID int, date time.Time) (err error) {
	dateStr := date.Format("2006-01-02")
	ctx, span := otel.Tracer("daily-question-service").Start(ctx, "RegenerateDailyQuestions",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.String("date", dateStr),
		),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Get user to determine language and level preferences
	user, err := s.getUserByID(ctx, userID)
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to get user")
	}
	if user == nil {
		return contextutils.ErrorWithContextf("user not found: %d", userID)
	}
	language := user.PreferredLanguage.String
	level := user.CurrentLevel.String
	if language == "" || level == "" {
		return contextutils.ErrorWithContextf("user missing language or level preferences")
	}

	// Resolve the daily goal. SubmitDailyQuestionAnswer already treats a nil
	// learningService as optional, so fall back to the default here instead
	// of panicking.
	goal := 10
	if s.learningService != nil {
		prefs, perr := s.learningService.GetUserLearningPreferences(ctx, userID)
		if perr != nil {
			span.RecordError(perr)
			return contextutils.WrapError(perr, "failed to get user learning preferences")
		}
		if prefs != nil && prefs.DailyGoal > 0 {
			goal = prefs.DailyGoal
		}
	}

	// Request extra candidates so that dropping nil or duplicate entries
	// still leaves enough to meet the goal.
	const candidateBuffer = 10
	questionsWithStats, err := s.questionService.GetAdaptiveQuestionsForDaily(ctx, userID, language, level, goal+candidateBuffer)
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to get adaptive questions for assignment")
	}

	// Keep at most goal distinct, non-nil questions; dereferencing a nil
	// entry would panic.
	questions := make([]models.Question, 0, goal)
	seen := make(map[int]bool, goal)
	for _, qws := range questionsWithStats {
		if len(questions) >= goal {
			break
		}
		if qws == nil || qws.Question == nil || seen[qws.ID] {
			continue
		}
		seen[qws.ID] = true
		questions = append(questions, *qws.Question)
	}

	// Bail out before touching existing assignments when nothing is usable.
	if len(questions) == 0 {
		// Gather diagnostics to explain why no questions were available
		var candidateIDs []int
		candidateCount := 0
		totalMatching := 0
		if s.questionService != nil {
			if candidates, qerr := s.questionService.GetAdaptiveQuestionsForDaily(ctx, userID, language, level, 50); qerr == nil && candidates != nil {
				candidateCount = len(candidates)
				for i, q := range candidates {
					if i >= 10 {
						break
					}
					if q != nil {
						candidateIDs = append(candidateIDs, q.ID)
					}
				}
			}
			if _, total, terr := s.questionService.GetAllQuestionsPaginated(ctx, 1, 1, "", "", "", language, level, nil); terr == nil {
				totalMatching = total
			}
		}
		return &NoQuestionsAvailableError{
			Language:       language,
			Level:          level,
			CandidateIDs:   candidateIDs,
			CandidateCount: candidateCount,
			TotalMatching:  totalMatching,
		}
	}

	// Begin transaction
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to begin transaction")
	}
	defer func() {
		if err != nil {
			if rollbackErr := tx.Rollback(); rollbackErr != nil {
				s.logger.Error(ctx, "Failed to rollback transaction", rollbackErr, map[string]interface{}{
					"user_id": userID,
					"date":    dateStr,
				})
			}
		}
	}()

	// First, delete existing assignments for this user and date
	deleteQuery := `DELETE FROM daily_question_assignments WHERE user_id = $1 AND assignment_date = $2`
	_, err = tx.ExecContext(ctx, deleteQuery, userID, date)
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to delete existing assignments")
	}

	// Insert new assignments
	insertQuery := `
INSERT INTO daily_question_assignments (user_id, question_id, assignment_date, created_at)
VALUES ($1, $2, $3, $4)
`
	stmt, err := tx.PrepareContext(ctx, insertQuery)
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to prepare statement")
	}
	defer func() {
		if closeErr := stmt.Close(); closeErr != nil {
			s.logger.Error(ctx, "Failed to close statement", closeErr, map[string]interface{}{
				"user_id": userID,
				"date":    dateStr,
			})
		}
	}()

	// questions is already capped at goal above.
	for _, question := range questions {
		_, err = stmt.ExecContext(ctx, userID, question.ID, date, time.Now())
		if err != nil {
			span.RecordError(err)
			return contextutils.WrapError(err, "failed to insert assignment")
		}
	}

	// Commit transaction
	err = tx.Commit()
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to commit transaction")
	}
	s.logger.Info(ctx, "Daily questions regenerated successfully", map[string]interface{}{
		"user_id": userID,
		"date":    dateStr,
		"count":   len(questions),
	})
	return nil
}
// GetDailyQuestions returns userID's daily question assignments for date,
// each joined with its question row and with per-user counters: how many
// times the question has been assigned to this user across all dates, and
// the user's total, correct and incorrect responses to it. Results are
// ordered by the assignment's created_at, oldest first.
//
// A row that fails to scan, or whose question content cannot be decoded, is
// logged and skipped instead of failing the whole call. When nothing matches,
// the result is a nil slice.
func (s *DailyQuestionService) GetDailyQuestions(ctx context.Context, userID int, date time.Time) (result0 []*models.DailyQuestionAssignmentWithQuestion, err error) {
	day := date.Format("2006-01-02")
	ctx, span := otel.Tracer("daily-question-service").Start(ctx, "GetDailyQuestions",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.String("date", day),
		),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// logFields builds a fresh field map for each log call.
	logFields := func() map[string]interface{} {
		return map[string]interface{}{
			"user_id": userID,
			"date":    day,
		}
	}

	const query = `
SELECT dqa.id, dqa.user_id, dqa.question_id, dqa.assignment_date,
dqa.is_completed, dqa.completed_at, dqa.created_at,
dqa.user_answer_index, dqa.submitted_at,
q.id, q.type, q.language, q.level, q.difficulty_score, q.content,
q.correct_answer, q.explanation, q.created_at, q.status,
q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario,
q.style_modifier, q.difficulty_modifier, q.time_context,
-- Daily shown count per user: how many times this user has seen this question in Daily across all dates
(SELECT COUNT(*) FROM daily_question_assignments dqa_all WHERE dqa_all.question_id = dqa.question_id AND dqa_all.user_id = dqa.user_id) AS daily_shown_count,
-- Per-user correctness stats across all time
COALESCE((SELECT COUNT(*) FROM user_responses ur WHERE ur.user_id = dqa.user_id AND ur.question_id = dqa.question_id), 0) AS user_total_responses,
COALESCE((SELECT COUNT(*) FROM user_responses ur WHERE ur.user_id = dqa.user_id AND ur.question_id = dqa.question_id AND ur.is_correct = TRUE), 0) AS user_correct_count,
COALESCE((SELECT COUNT(*) FROM user_responses ur WHERE ur.user_id = dqa.user_id AND ur.question_id = dqa.question_id AND ur.is_correct = FALSE), 0) AS user_incorrect_count
FROM daily_question_assignments dqa
JOIN questions q ON dqa.question_id = q.id
WHERE dqa.user_id = $1 AND dqa.assignment_date = $2
ORDER BY dqa.created_at ASC
`
	rows, err := s.db.QueryContext(ctx, query, userID, date)
	if err != nil {
		span.RecordError(err)
		return nil, contextutils.WrapError(err, "failed to query daily questions")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Error(ctx, "Failed to close rows", closeErr, logFields())
		}
	}()

	var out []*models.DailyQuestionAssignmentWithQuestion
	for rows.Next() {
		item := &models.DailyQuestionAssignmentWithQuestion{}
		q := &models.Question{}
		var rawContent string

		scanErr := rows.Scan(
			&item.ID, &item.UserID, &item.QuestionID, &item.AssignmentDate,
			&item.IsCompleted, &item.CompletedAt, &item.CreatedAt,
			&item.UserAnswerIndex, &item.SubmittedAt,
			&q.ID, &q.Type, &q.Language, &q.Level, &q.DifficultyScore,
			&rawContent, &q.CorrectAnswer, &q.Explanation, &q.CreatedAt, &q.Status,
			&q.TopicCategory, &q.GrammarFocus, &q.VocabularyDomain, &q.Scenario,
			&q.StyleModifier, &q.DifficultyModifier, &q.TimeContext,
			&item.DailyShownCount,
			&item.UserTotalResponses,
			&item.UserCorrectCount,
			&item.UserIncorrectCount,
		)
		if scanErr != nil {
			s.logger.Error(ctx, "Failed to scan daily question assignment", scanErr, logFields())
			continue
		}

		// content is stored as JSON text and decoded into the question.
		if decodeErr := q.UnmarshalContentFromJSON(rawContent); decodeErr != nil {
			fields := logFields()
			fields["content"] = rawContent
			s.logger.Error(ctx, "Failed to unmarshal question content", decodeErr, fields)
			continue
		}

		item.Question = q
		out = append(out, item)
	}

	if err = rows.Err(); err != nil {
		span.RecordError(err)
		return nil, contextutils.WrapError(err, "error iterating over rows")
	}
	return out, nil
}
// MarkQuestionCompleted sets is_completed = true and completed_at = now on
// userID's assignment of questionID for date. It does not touch the stored
// answer. Returns contextutils.ErrAssignmentNotFound when no assignment row
// matches.
func (s *DailyQuestionService) MarkQuestionCompleted(ctx context.Context, userID, questionID int, date time.Time) (err error) {
	day := date.Format("2006-01-02")
	ctx, span := otel.Tracer("daily-question-service").Start(ctx, "MarkQuestionCompleted",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.Int("question.id", questionID),
			attribute.String("date", day),
		),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const markSQL = `
UPDATE daily_question_assignments
SET is_completed = true, completed_at = $1
WHERE user_id = $2 AND question_id = $3 AND assignment_date = $4
`
	res, err := s.db.ExecContext(ctx, markSQL, time.Now(), userID, questionID, date)
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to mark question as completed")
	}

	affected, err := res.RowsAffected()
	switch {
	case err != nil:
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to get rows affected")
	case affected == 0:
		return contextutils.ErrAssignmentNotFound
	}

	s.logger.Info(ctx, "Question marked as completed", map[string]interface{}{
		"user_id":     userID,
		"question_id": questionID,
		"date":        day,
	})
	return nil
}
// ResetQuestionCompleted clears the completion state of userID's assignment
// of questionID for date: is_completed becomes false, and completed_at,
// user_answer_index and submitted_at are set to NULL. Returns
// contextutils.ErrAssignmentNotFound when no assignment row matches.
func (s *DailyQuestionService) ResetQuestionCompleted(ctx context.Context, userID, questionID int, date time.Time) (err error) {
	day := date.Format("2006-01-02")
	ctx, span := otel.Tracer("daily-question-service").Start(ctx, "ResetQuestionCompleted",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.Int("question.id", questionID),
			attribute.String("date", day),
		),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const resetSQL = `
UPDATE daily_question_assignments
SET is_completed = false, completed_at = NULL, user_answer_index = NULL, submitted_at = NULL
WHERE user_id = $1 AND question_id = $2 AND assignment_date = $3
`
	res, err := s.db.ExecContext(ctx, resetSQL, userID, questionID, date)
	if err != nil {
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to reset question completion")
	}

	affected, err := res.RowsAffected()
	switch {
	case err != nil:
		span.RecordError(err)
		return contextutils.WrapError(err, "failed to get rows affected")
	case affected == 0:
		return contextutils.ErrAssignmentNotFound
	}

	s.logger.Info(ctx, "Question reset to not completed", map[string]interface{}{
		"user_id":     userID,
		"question_id": questionID,
		"date":        day,
	})
	return nil
}
// GetAvailableDates lists the distinct assignment dates on which userID has
// at least one daily question assignment, newest first. Rows that fail to
// scan are logged and skipped. Returns a nil slice when there are none.
func (s *DailyQuestionService) GetAvailableDates(ctx context.Context, userID int) (result0 []time.Time, err error) {
	tracer := otel.Tracer("daily-question-service")
	ctx, span := tracer.Start(ctx, "GetAvailableDates",
		trace.WithAttributes(attribute.Int("user.id", userID)),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const datesSQL = `
SELECT DISTINCT assignment_date
FROM daily_question_assignments
WHERE user_id = $1
ORDER BY assignment_date DESC
`
	rows, err := s.db.QueryContext(ctx, datesSQL, userID)
	if err != nil {
		span.RecordError(err)
		return nil, contextutils.WrapError(err, "failed to query available dates")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{
				"user_id": userID,
			})
		}
	}()

	var found []time.Time
	for rows.Next() {
		var day time.Time
		if scanErr := rows.Scan(&day); scanErr != nil {
			s.logger.Error(ctx, "Failed to scan date", scanErr, map[string]interface{}{
				"user_id": userID,
			})
			continue
		}
		found = append(found, day)
	}

	if err = rows.Err(); err != nil {
		span.RecordError(err)
		return nil, contextutils.WrapError(err, "error iterating over rows")
	}
	return found, nil
}
// GetDailyProgress counts userID's assignments for date (Total) and how many
// of them have is_completed = true (Completed), returning both together with
// the requested date.
func (s *DailyQuestionService) GetDailyProgress(ctx context.Context, userID int, date time.Time) (result0 *models.DailyProgress, err error) {
	ctx, span := otel.Tracer("daily-question-service").Start(ctx, "GetDailyProgress",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.String("date", date.Format("2006-01-02")),
		),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const progressSQL = `
SELECT
COUNT(*) as total,
COUNT(CASE WHEN is_completed = true THEN 1 END) as completed
FROM daily_question_assignments
WHERE user_id = $1 AND assignment_date = $2
`
	progress := &models.DailyProgress{Date: date}
	if err = s.db.QueryRowContext(ctx, progressSQL, userID, date).Scan(&progress.Total, &progress.Completed); err != nil {
		return nil, contextutils.WrapError(err, "failed to get daily progress")
	}
	return progress, nil
}
// GetDailyQuestionsCount returns how many daily question assignments userID
// has for date, whether or not they are completed.
func (s *DailyQuestionService) GetDailyQuestionsCount(ctx context.Context, userID int, date time.Time) (result0 int, err error) {
	ctx, span := otel.Tracer("daily-question-service").Start(ctx, "GetDailyQuestionsCount",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.String("date", date.Format("2006-01-02")),
		),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const countSQL = `
SELECT COUNT(*)
FROM daily_question_assignments
WHERE user_id = $1 AND assignment_date = $2
`
	var total int
	if err = s.db.QueryRowContext(ctx, countSQL, userID, date).Scan(&total); err != nil {
		return 0, contextutils.WrapError(err, "failed to get daily questions count")
	}
	return total, nil
}
// GetCompletedDailyQuestionsCount returns how many of userID's daily question
// assignments for date have is_completed = true.
func (s *DailyQuestionService) GetCompletedDailyQuestionsCount(ctx context.Context, userID int, date time.Time) (result0 int, err error) {
	ctx, span := otel.Tracer("daily-question-service").Start(ctx, "GetCompletedDailyQuestionsCount",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.String("date", date.Format("2006-01-02")),
		),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const completedSQL = `
SELECT COUNT(*)
FROM daily_question_assignments
WHERE user_id = $1 AND assignment_date = $2 AND is_completed = true
`
	var completed int
	if err = s.db.QueryRowContext(ctx, completedSQL, userID, date).Scan(&completed); err != nil {
		return 0, contextutils.WrapError(err, "failed to get completed daily questions count")
	}
	return completed, nil
}
// GetQuestionHistory returns userID's daily assignments of questionID whose
// assignment_date lies between NOW() minus days days and CURRENT_DATE plus
// one day, oldest first.
//
// IsCorrect comes from the user_responses row linked through
// daily_assignment_responses and is nil when no response is linked. Rows that
// fail to scan are logged and skipped. Returns an error when days <= 0.
func (s *DailyQuestionService) GetQuestionHistory(ctx context.Context, userID, questionID, days int) (result0 []*models.DailyQuestionHistory, err error) {
	ctx, span := otel.Tracer("daily-question-service").Start(ctx, "GetQuestionHistory",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.Int("question.id", questionID),
			attribute.Int("days", days),
		),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	if days <= 0 {
		return nil, contextutils.ErrorWithContextf("days must be positive")
	}

	// days is bound as $3 instead of being formatted into the SQL text, so
	// the statement text is constant and no value is spliced into SQL.
	// ($3 * INTERVAL '1 day') equals INTERVAL '<days> days'.
	const query = `
SELECT dqa.assignment_date, dqa.is_completed, dqa.submitted_at,
ur.is_correct
FROM daily_question_assignments dqa
LEFT JOIN daily_assignment_responses dar ON dar.assignment_id = dqa.id
LEFT JOIN user_responses ur ON ur.id = dar.user_response_id
WHERE dqa.user_id = $1 AND dqa.question_id = $2
AND dqa.assignment_date >= NOW() - ($3::integer * INTERVAL '1 day')
AND dqa.assignment_date <= CURRENT_DATE + INTERVAL '1 day'
ORDER BY dqa.assignment_date ASC
`
	rows, err := s.db.QueryContext(ctx, query, userID, questionID, days)
	if err != nil {
		span.RecordError(err)
		return nil, contextutils.WrapError(err, "failed to query question history")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{
				"user_id":     userID,
				"question_id": questionID,
				"days":        days,
			})
		}
	}()

	var history []*models.DailyQuestionHistory
	for rows.Next() {
		var historyEntry models.DailyQuestionHistory
		// is_correct is NULL when no response is linked (LEFT JOINs above).
		var isCorrect sql.NullBool
		scanErr := rows.Scan(
			&historyEntry.AssignmentDate,
			&historyEntry.IsCompleted,
			&historyEntry.SubmittedAt,
			&isCorrect,
		)
		if scanErr != nil {
			s.logger.Error(ctx, "Failed to scan question history entry", scanErr, map[string]interface{}{
				"user_id":         userID,
				"question_id":     questionID,
				"assignment_date": historyEntry.AssignmentDate,
			})
			continue
		}
		// IsCorrect stays nil (its zero value) when no response is linked.
		if isCorrect.Valid {
			historyEntry.IsCorrect = &isCorrect.Bool
		}
		history = append(history, &historyEntry)
	}

	if err = rows.Err(); err != nil {
		span.RecordError(err)
		return nil, contextutils.WrapError(err, "error iterating over rows")
	}
	return history, nil
}
// getUserByID loads the users row with the given id. It returns (nil, nil)
// when no such row exists, and (nil, err) for any other query or scan error.
func (s *DailyQuestionService) getUserByID(ctx context.Context, userID int) (*models.User, error) {
	const userSQL = `
SELECT id, username, email, timezone, password_hash, last_active,
preferred_language, current_level, ai_provider, ai_model,
ai_enabled, ai_api_key, created_at, updated_at
FROM users
WHERE id = $1
`
	u := new(models.User)
	row := s.db.QueryRowContext(ctx, userSQL, userID)
	switch scanErr := row.Scan(
		&u.ID, &u.Username, &u.Email, &u.Timezone, &u.PasswordHash,
		&u.LastActive, &u.PreferredLanguage, &u.CurrentLevel, &u.AIProvider,
		&u.AIModel, &u.AIEnabled, &u.AIAPIKey, &u.CreatedAt, &u.UpdatedAt,
	); {
	case scanErr == sql.ErrNoRows:
		// A missing user is not an error for callers; they check for nil.
		return nil, nil
	case scanErr != nil:
		return nil, scanErr
	}
	return u, nil
}
// SubmitDailyQuestionAnswer submits an answer for a daily question and marks it as completed
func (s *DailyQuestionService) SubmitDailyQuestionAnswer(ctx context.Context, userID, questionID int, date time.Time, userAnswerIndex int) (result *api.AnswerResponse, err error) {
ctx, span := otel.Tracer("daily-question-service").Start(ctx, "SubmitDailyQuestionAnswer",
trace.WithAttributes(
attribute.Int("user.id", userID),
attribute.Int("question.id", questionID),
attribute.String("date", date.Format("2006-01-02")),
attribute.Int("user_answer_index", userAnswerIndex),
),
)
defer func() {
if err != nil {
span.RecordError(err, trace.WithStackTrace(true))
span.SetStatus(codes.Error, err.Error())
}
span.End()
}()
s.logger.Info(ctx, "SubmitDailyQuestionAnswer started", map[string]interface{}{
"user_id": userID,
"question_id": questionID,
"date": date.Format("2006-01-02"),
"user_answer_index": userAnswerIndex,
})
// Check if the question is already answered
s.logger.Info(ctx, "Checking if question is already answered", map[string]interface{}{
"user_id": userID,
"question_id": questionID,
"date": date.Format("2006-01-02"),
})
query := `
SELECT id, is_completed, user_answer_index, submitted_at
FROM daily_question_assignments
WHERE user_id = $1 AND question_id = $2 AND assignment_date = $3
`
var assignmentID int
var isCompleted bool
var existingUserAnswerIndex *int
var existingSubmittedAt *time.Time
err = s.db.QueryRowContext(ctx, query, userID, questionID, date).Scan(
&assignmentID, &isCompleted, &existingUserAnswerIndex, &existingSubmittedAt,
)
if err != nil {
if err == sql.ErrNoRows {
return nil, contextutils.ErrAssignmentNotFound
}
return nil, contextutils.WrapError(err, "failed to check question assignment")
}
// Check if already answered
if isCompleted && existingUserAnswerIndex != nil && existingSubmittedAt != nil {
return nil, contextutils.ErrQuestionAlreadyAnswered
}
// Get the question details to validate answer and get correct answer
question, err := s.questionService.GetQuestionByID(ctx, questionID)
if err != nil {
return nil, contextutils.WrapError(err, "failed to get question details")
}
if question == nil {
return nil, contextutils.ErrQuestionNotFound
}
// Extract options from content map
contentMap := question.Content
s.logger.Info(ctx, "Question content debug", map[string]interface{}{
"question_id": questionID,
"content_map": contentMap,
})
optionsInterface, ok := contentMap["options"]
if !ok {
s.logger.Error(ctx, "Question content missing options", nil, map[string]interface{}{
"question_id": questionID,
"content_map": contentMap,
})
return nil, contextutils.ErrorWithContextf("question content missing options")
}
options, ok := optionsInterface.([]interface{})
if !ok {
s.logger.Error(ctx, "Invalid options format", nil, map[string]interface{}{
"question_id": questionID,
"options_interface": optionsInterface,
"options_type": fmt.Sprintf("%T", optionsInterface),
})
return nil, contextutils.ErrorWithContextf("invalid options format")
}
// Validate user answer index
if userAnswerIndex < 0 || userAnswerIndex >= len(options) {
return nil, contextutils.ErrInvalidAnswerIndex
}
// Check if answer is correct
isCorrect := question.CorrectAnswer == userAnswerIndex
// Begin transaction
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, contextutils.WrapError(err, "failed to begin transaction")
}
defer func() {
if err != nil {
if rollbackErr := tx.Rollback(); rollbackErr != nil {
s.logger.Error(ctx, "Failed to rollback transaction", rollbackErr, map[string]interface{}{
"error": rollbackErr.Error(),
})
}
}
}()
// Update the assignment with the user's answer and mark as completed
updateQuery := `
UPDATE daily_question_assignments
SET is_completed = true, completed_at = NOW(), user_answer_index = $1, submitted_at = NOW()
WHERE id = $2
`
_, err = tx.ExecContext(ctx, updateQuery, userAnswerIndex, assignmentID)
if err != nil {
return nil, contextutils.WrapError(err, "failed to update assignment")
}
// Commit transaction
err = tx.Commit()
if err != nil {
return nil, contextutils.WrapError(err, "failed to commit transaction")
}
// Record canonical user response via learningService so history queries see is_correct
// Use RecordAnswerWithPriorityReturningID to obtain user_responses.id so we can link it to the assignment.
if s.learningService != nil {
// record synchronously so we have the response id for mapping
respID, recErr := s.learningService.RecordAnswerWithPriorityReturningID(ctx, userID, questionID, userAnswerIndex, isCorrect, 0)
if recErr != nil {
s.logger.Error(ctx, "Failed to record user response for daily answer", recErr, map[string]interface{}{
"user_id": userID,
"question_id": questionID,
"user_answer_index": userAnswerIndex,
})
} else {
// Insert mapping to daily_assignment_responses synchronously so tests that run immediately can observe it
_, mapErr := s.db.ExecContext(ctx, `
INSERT INTO daily_assignment_responses (assignment_id, user_response_id, created_at)
VALUES ($1, $2, NOW())
ON CONFLICT (assignment_id) DO UPDATE SET user_response_id = EXCLUDED.user_response_id, created_at = EXCLUDED.created_at
`, assignmentID, respID)
if mapErr != nil {
// Log but don't fail user's request
s.logger.Error(ctx, "Failed to insert daily_assignment_responses mapping", mapErr, map[string]interface{}{
"assignment_id": assignmentID,
"user_response_id": respID,
})
}
// If the answer was correct, remove future assignments for this question within the avoid window
if isCorrect {
// Determine avoidDays via questionService if possible; default to 7
avoidDays := 7
switch qs := s.questionService.(type) {
case interface{ getDailyRepeatAvoidDays() int }:
avoidDays = qs.getDailyRepeatAvoidDays()
default:
// leave default
}
startDate := date.AddDate(0, 0, 1)
endDate := date.AddDate(0, 0, avoidDays)
deleteQuery := `DELETE FROM daily_question_assignments WHERE user_id = $1 AND question_id = $2 AND assignment_date >= $3 AND assignment_date <= $4`
if _, delErr := s.db.ExecContext(ctx, deleteQuery, userID, questionID, startDate, endDate); delErr != nil {
s.logger.Error(ctx, "Failed to delete future daily assignments", delErr, map[string]interface{}{
"user_id": userID,
"question_id": questionID,
"start": startDate,
"end": endDate,
})
} else {
// Future assignments removed successfully; worker will top up missing slots on its next run
s.logger.Info(ctx, "Deleted future daily assignments for question; worker will refill dates as needed", map[string]interface{}{
"user_id": userID,
"question_id": questionID,
"start": startDate,
"end": endDate,
})
}
}
}
}
// Build response
userAnswer := options[userAnswerIndex].(string)
response := &api.AnswerResponse{
UserAnswerIndex: &userAnswerIndex,
UserAnswer: &userAnswer,
IsCorrect: &isCorrect,
}
// Add correct answer and explanation if available
response.CorrectAnswerIndex = &question.CorrectAnswer
if question.Explanation != "" {
response.Explanation = &question.Explanation
}
s.logger.Info(ctx, "Daily question answer submitted", map[string]interface{}{
"user_id": userID,
"question_id": questionID,
"date": date.Format("2006-01-02"),
"user_answer_index": userAnswerIndex,
"is_correct": isCorrect,
})
return response, nil
}
// Package services provides business logic services for the quiz application.
package services
import (
"context"
"database/sql"
"quizapp/internal/config"
"quizapp/internal/observability"
"quizapp/internal/services/mailer"
)
// CreateEmailService chooses the mailer implementation for cfg.
//
// When cfg.IsTest is set it logs the choice and returns a TestEmailService;
// otherwise it returns the SMTP-backed EmailService, built without a
// database connection.
func CreateEmailService(cfg *config.Config, logger *observability.Logger) mailer.Mailer {
	if !cfg.IsTest {
		return NewEmailService(cfg, logger)
	}

	logger.Info(context.Background(), "Using test email service", map[string]interface{}{
		"test_mode": true,
	})
	return NewTestEmailService(cfg, logger)
}
// CreateEmailServiceWithDB chooses the mailer implementation for cfg and
// wires in a database connection.
//
// In test mode (cfg.IsTest) it returns a TestEmailService built with db. In
// all other modes db is mandatory. A nil db is logged and then causes a
// panic, because the production EmailService cannot operate without it.
func CreateEmailServiceWithDB(cfg *config.Config, logger *observability.Logger, db *sql.DB) mailer.Mailer {
	bg := context.Background()

	switch {
	case cfg.IsTest:
		logger.Info(bg, "Using test email service with DB", map[string]interface{}{
			"test_mode": true,
		})
		return NewTestEmailServiceWithDB(cfg, logger, db)
	case db == nil:
		logger.Error(bg, "Database connection is nil, cannot create EmailService", nil, map[string]interface{}{
			"error": "nil_database_connection",
		})
		panic("EmailService requires a non-nil database connection")
	default:
		return NewEmailServiceWithDB(cfg, logger, db)
	}
}
// Package services provides business logic services for the quiz application.
package services
import (
"context"
"database/sql"
"fmt"
"html/template"
"strings"
"time"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
serviceinterfaces "quizapp/internal/services/interfaces"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"gopkg.in/mail.v2"
)
// EmailService implements the interfaces.EmailService interface using gomail
// (gopkg.in/mail.v2) to deliver HTML email over SMTP.
type EmailService struct {
	cfg    *config.Config
	logger *observability.Logger
	// dialer is nil unless email is enabled and an SMTP host is configured
	// at construction time.
	dialer *mail.Dialer
	// db is nil when the service is built via NewEmailService; it is set by
	// NewEmailServiceWithDB.
	db *sql.DB
}
// EmailServiceInterface is an alias of serviceinterfaces.EmailService,
// re-exported from this package; it defines the email functionality contract.
type EmailServiceInterface = serviceinterfaces.EmailService

// Compile-time proof that *EmailService satisfies the email contract. The
// alias and the original interface are the same type, so this is equivalent
// to asserting against serviceinterfaces.EmailService directly.
var _ EmailServiceInterface = (*EmailService)(nil)
// NewEmailService builds an EmailService with no database connection.
//
// An SMTP dialer is created only when email is enabled and an SMTP host is
// configured; otherwise the dialer is left nil.
func NewEmailService(cfg *config.Config, logger *observability.Logger) *EmailService {
	svc := &EmailService{cfg: cfg, logger: logger}

	smtpCfg := cfg.Email.SMTP
	if cfg.Email.Enabled && smtpCfg.Host != "" {
		svc.dialer = mail.NewDialer(smtpCfg.Host, smtpCfg.Port, smtpCfg.Username, smtpCfg.Password)
	}
	return svc
}
// NewEmailServiceWithDB creates a new EmailService instance with database connection
func NewEmailServiceWithDB(cfg *config.Config, logger *observability.Logger, db *sql.DB) *EmailService {
if db == nil {
panic("EmailService requires a non-nil database connection")
}
var dialer *mail.Dialer
if cfg.Email.Enabled && cfg.Email.SMTP.Host != "" {
dialer = mail.NewDialer(
cfg.Email.SMTP.Host,
cfg.Email.SMTP.Port,
cfg.Email.SMTP.Username,
cfg.Email.SMTP.Password,
)
}
return &EmailService{
cfg: cfg,
logger: logger,
dialer: dialer,
db: db,
}
}
// SendDailyReminder emails user a reminder to complete today's daily questions.
//
// It returns nil without sending when email is disabled or the user has no
// usable email address.
//
// The user's daily goal is read from user_learning_preferences when a
// database connection is available. The default of 10 is used instead when:
//   - the service has no DB (built via NewEmailService),
//   - the row is missing or the value is NULL, or
//   - the lookup fails.
// This way a preferences problem never blocks the reminder itself.
//
// It returns an error if user is nil or if rendering or sending the email fails.
func (e *EmailService) SendDailyReminder(ctx context.Context, user *models.User) (err error) {
	// Check before starting the span: the span attributes dereference user.
	if user == nil {
		return contextutils.ErrorWithContextf("cannot send daily reminder: user is nil")
	}

	ctx, span := otel.Tracer("email-service").Start(ctx, "SendDailyReminder",
		trace.WithAttributes(
			attribute.Int("user.id", user.ID),
			attribute.String("user.email", user.Email.String),
		),
	)
	defer observability.FinishSpan(span, &err)

	if !e.IsEnabled() {
		e.logger.Info(ctx, "Email disabled, skipping daily reminder", map[string]interface{}{
			"user_id": user.ID,
			"email":   user.Email.String,
		})
		return nil
	}

	if !user.Email.Valid || user.Email.String == "" {
		e.logger.Warn(ctx, "User has no email address, skipping daily reminder", map[string]interface{}{
			"user_id": user.ID,
		})
		return nil
	}

	// Determine daily goal from DB. e.db is nil for services built with
	// NewEmailService, and calling QueryRowContext on a nil *sql.DB panics.
	dailyGoal := 10
	if e.db != nil {
		var dg sql.NullInt64
		lookupErr := e.db.QueryRowContext(ctx, "SELECT daily_goal FROM user_learning_preferences WHERE user_id = $1", user.ID).Scan(&dg)
		switch {
		case lookupErr == nil && dg.Valid:
			dailyGoal = int(dg.Int64)
		case lookupErr != nil && lookupErr != sql.ErrNoRows:
			// Best-effort lookup: keep the default, but make the failure visible.
			e.logger.Warn(ctx, "Failed to load daily goal, using default", map[string]interface{}{
				"user_id": user.ID,
				"error":   lookupErr.Error(),
			})
		}
	}

	// Generate email data
	data := map[string]interface{}{
		"Username":       user.Username,
		"QuizAppURL":     e.cfg.Server.AppBaseURL, // Frontend app URL for email links
		"CurrentDate":    time.Now().Format("January 2, 2006"),
		"DailyGoal":      dailyGoal,
		"UnsubscribeURL": fmt.Sprintf("%s/settings", e.cfg.Server.AppBaseURL),
	}

	subject := "Time for your daily quiz! ð"
	err = e.SendEmail(ctx, user.Email.String, subject, "daily_reminder", data)
	if err != nil {
		return contextutils.WrapError(err, "failed to send daily reminder")
	}

	e.logger.Info(ctx, "Daily reminder sent successfully", map[string]interface{}{
		"user_id": user.ID,
		"email":   user.Email.String,
	})
	return nil
}
// SendEmail renders templateName with data and sends the result as an HTML
// email to the given recipient with the given subject.
//
// When email is disabled it logs and returns nil without sending. It returns
// an error in these cases:
//   - the SMTP dialer is not configured,
//   - the template is unknown or fails to render, or
//   - SMTP delivery fails.
func (e *EmailService) SendEmail(ctx context.Context, to, subject, templateName string, data map[string]interface{}) (err error) {
	ctx, span := otel.Tracer("email-service").Start(ctx, "SendEmail",
		trace.WithAttributes(
			attribute.String("email.to", to),
			attribute.String("email.subject", subject),
			attribute.String("email.template", templateName),
		),
	)
	defer observability.FinishSpan(span, &err)

	if !e.IsEnabled() {
		e.logger.Info(ctx, "Email disabled, skipping email send", map[string]interface{}{
			"to":       to,
			"template": templateName,
		})
		return nil
	}

	if e.dialer == nil {
		return contextutils.ErrorWithContextf("email service not properly configured")
	}

	// Render the body before building the message so a template error fails fast.
	content, err := e.generateEmailContent(templateName, data)
	if err != nil {
		return contextutils.WrapError(err, "failed to generate email content")
	}

	m := mail.NewMessage()
	// SetAddressHeader quotes/encodes the display name (RFC 5322 / RFC 2047).
	// Hand-building "Name <addr>" with SetHeader breaks on commas, quotes, or
	// non-ASCII names: the whole value gets Q-encoded or fails address parsing.
	m.SetAddressHeader("From", e.cfg.Email.SMTP.FromAddress, e.cfg.Email.SMTP.FromName)
	m.SetHeader("To", to)
	m.SetHeader("Subject", subject)
	m.SetBody("text/html", content)

	// Send email
	if err = e.dialer.DialAndSend(m); err != nil {
		e.logger.Error(ctx, "Failed to send email", err, map[string]interface{}{
			"to":       to,
			"template": templateName,
			"subject":  subject,
		})
		return contextutils.WrapError(err, "failed to send email")
	}

	e.logger.Info(ctx, "Email sent successfully", map[string]interface{}{
		"to":       to,
		"template": templateName,
		"subject":  subject,
	})
	return nil
}
// RecordSentNotification records a sent notification in the database
func (e *EmailService) RecordSentNotification(ctx context.Context, userID int, notificationType, subject, templateName, status, errorMessage string) (err error) {
ctx, span := otel.Tracer("email-service").Start(ctx, "RecordSentNotification",
trace.WithAttributes(
attribute.Int("user.id", userID),
attribute.String("notification.type", notificationType),
attribute.String("notification.status", status),
),
)
defer observability.FinishSpan(span, &err)
if e.db == nil {
e.logger.Error(ctx, "Database connection is nil, cannot record notification", nil, map[string]interface{}{
"user_id": userID,
"notification_type": notificationType,
})
return contextutils.ErrorWithContextf("EmailService database connection is nil")
}
query := `
INSERT INTO sent_notifications (user_id, notification_type, subject, template_name, sent_at, status, error_message)
VALUES ($1, $2, $3, $4, $5, $6, $7)
`
_, err = e.db.ExecContext(ctx, query, userID, notificationType, subject, templateName, time.Now(), status, errorMessage)
if err != nil {
e.logger.Error(ctx, "Failed to record sent notification", err, map[string]interface{}{
"user_id": userID,
"notification_type": notificationType,
"status": status,
})
return contextutils.WrapError(err, "failed to record sent notification")
}
e.logger.Info(ctx, "Recorded sent notification", map[string]interface{}{
"user_id": userID,
"notification_type": notificationType,
"status": status,
})
return nil
}
// IsEnabled reports whether outgoing email is switched on in the
// configuration and an SMTP host has been provided.
func (e *EmailService) IsEnabled() bool {
	emailCfg := e.cfg.Email
	if !emailCfg.Enabled {
		return false
	}
	return emailCfg.SMTP.Host != ""
}
// generateEmailContent renders the named built-in template with data.
//
// Supported names are "daily_reminder" and "test_email"; any other name
// yields an "unknown template" error. Templates are currently compiled into
// the binary rather than loaded from files or the database.
func (e *EmailService) generateEmailContent(templateName string, data map[string]interface{}) (string, error) {
	var render func(map[string]interface{}) (string, error)
	switch templateName {
	case "daily_reminder":
		render = e.generateDailyReminderTemplate
	case "test_email":
		render = e.generateTestEmailTemplate
	default:
		return "", contextutils.ErrorWithContextf("unknown template: %s", templateName)
	}
	return render(data)
}
// generateDailyReminderTemplate renders the HTML body of the daily reminder
// email.
//
// The markup reads Username, CurrentDate, DailyGoal, QuizAppURL and
// UnsubscribeURL from data. html/template escapes each interpolated value for
// its HTML or URL context.
func (e *EmailService) generateDailyReminderTemplate(data map[string]interface{}) (string, error) {
	const markup = `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Daily Quiz Reminder</title>
<style>
body { font-family: Arial, sans-serif; line-height: 1.6; color: #333; }
.container { max-width: 600px; margin: 0 auto; padding: 20px; }
.header { background-color: #4CAF50; color: white; padding: 20px; text-align: center; border-radius: 5px 5px 0 0; }
.content { background-color: #f9f9f9; padding: 20px; }
.button { display: inline-block; background-color: #4CAF50; color: white; padding: 12px 24px; text-decoration: none; border-radius: 5px; margin: 20px 0; }
.footer { background-color: #eee; padding: 15px; text-align: center; font-size: 12px; color: #666; border-radius: 0 0 5px 5px; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>ð Daily Quiz Reminder</h1>
</div>
<div class="content">
<h2>Hello {{.Username}}!</h2>
<p>It's {{.CurrentDate}} and time for your daily questions!</p>
<p>Your goal today: <strong>{{.DailyGoal}} questions</strong></p>
<p>Keep up the great work and continue improving your language skills!</p>
<div style="text-align: center;">
<a href="{{.QuizAppURL}}/daily" class="button">Start Your Daily Questions</a>
</div>
</div>
<div class="footer">
<p>This email was sent by Quiz App. If you no longer wish to receive these reminders, you can <a href="{{.UnsubscribeURL}}">unsubscribe here</a>.</p>
</div>
</div>
</body>
</html>`

	var rendered strings.Builder
	t, parseErr := template.New("daily_reminder").Parse(markup)
	if parseErr != nil {
		return "", contextutils.WrapError(parseErr, "failed to parse template")
	}
	if execErr := t.Execute(&rendered, data); execErr != nil {
		return "", contextutils.WrapError(execErr, "failed to execute template")
	}
	return rendered.String(), nil
}
// generateTestEmailTemplate renders the HTML body of the configuration test
// email.
//
// The markup reads Username, TestTime and Message from data. html/template
// escapes each interpolated value for its HTML context.
func (e *EmailService) generateTestEmailTemplate(data map[string]interface{}) (string, error) {
	const markup = `
<!DOCTYPE html>
<html>
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Test Email</title>
<style>
body { font-family: Arial, sans-serif; line-height: 1.6; color: #333; }
.container { max-width: 600px; margin: 0 auto; padding: 20px; }
.header { background-color: #2196F3; color: white; padding: 20px; text-align: center; border-radius: 5px 5px 0 0; }
.content { background-color: #f9f9f9; padding: 20px; }
.footer { background-color: #eee; padding: 15px; text-align: center; font-size: 12px; color: #666; border-radius: 0 0 5px 5px; }
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>ð Test Email</h1>
</div>
<div class="content">
<h2>Hello {{.Username}}!</h2>
<p>This is a test email to verify that your email settings are working correctly.</p>
<p><strong>Test Time:</strong> {{.TestTime}}</p>
<p><strong>Message:</strong> {{.Message}}</p>
<p>If you received this email, your email configuration is working properly!</p>
</div>
<div class="footer">
<p>This is a test email from Quiz App. No action is required.</p>
</div>
</div>
</body>
</html>
`

	var rendered strings.Builder
	t, parseErr := template.New("test_email").Parse(markup)
	if parseErr != nil {
		return "", contextutils.WrapError(parseErr, "failed to parse template")
	}
	if execErr := t.Execute(&rendered, data); execErr != nil {
		return "", contextutils.WrapError(execErr, "failed to execute template")
	}
	return rendered.String(), nil
}
package services
import (
"context"
"database/sql"
"time"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
)
// GenerationHint represents an active generation hint: one row of the
// generation_hints table, identifying a (user, language, level, question type)
// combination the worker should prioritize until ExpiresAt.
type GenerationHint struct {
	ID             int       `db:"id"`
	UserID         int       `db:"user_id"`
	Language       string    `db:"language"`
	Level          string    `db:"level"`
	QuestionType   string    `db:"question_type"`   // stored as string(models.QuestionType)
	PriorityWeight int       `db:"priority_weight"` // starts at 1; UpsertHint adds 1 on each repeat
	ExpiresAt      time.Time `db:"expires_at"`
	CreatedAt      time.Time `db:"created_at"`
}
// GenerationHintServiceInterface defines the API for managing generation hints
// (short-lived, per-user entries keyed by language, level, and question type).
// The behavior notes below describe the GenerationHintService implementation.
type GenerationHintServiceInterface interface {
	// UpsertHint creates the hint, or bumps its priority weight and resets its
	// expiry to now+ttl if it already exists.
	UpsertHint(ctx context.Context, userID int, language, level string, qType models.QuestionType, ttl time.Duration) error
	// GetActiveHintsForUser returns the user's unexpired hints, oldest first.
	GetActiveHintsForUser(ctx context.Context, userID int) ([]GenerationHint, error)
	// ClearHint deletes the hint for the given user, language, level, and type.
	ClearHint(ctx context.Context, userID int, language, level string, qType models.QuestionType) error
}
// GenerationHintService implements hint management on top of the
// generation_hints table.
type GenerationHintService struct {
	db     *sql.DB
	logger *observability.Logger // NOTE(review): not used by the methods visible in this file
}
// NewGenerationHintService returns a service for short-lived per-user
// generation hints. The worker uses these hints to prioritize specific
// question types (for example reading comprehension) while a user is waiting
// for generation.
func NewGenerationHintService(db *sql.DB, logger *observability.Logger) *GenerationHintService {
	svc := new(GenerationHintService)
	svc.db = db
	svc.logger = logger
	return svc
}
// UpsertHint creates or refreshes a hint with the given TTL
func (s *GenerationHintService) UpsertHint(ctx context.Context, userID int, language, level string, qType models.QuestionType, ttl time.Duration) (err error) {
ctx, span := observability.TraceWorkerFunction(ctx, "upsert_generation_hint")
defer observability.FinishSpan(span, &err)
expiresAt := time.Now().Add(ttl)
_, err = s.db.ExecContext(ctx, `
INSERT INTO generation_hints (user_id, language, level, question_type, priority_weight, expires_at)
VALUES ($1, $2, $3, $4, 1, $5)
ON CONFLICT (user_id, language, level, question_type) DO UPDATE SET
priority_weight = generation_hints.priority_weight + 1,
expires_at = EXCLUDED.expires_at,
created_at = generation_hints.created_at
`, userID, language, level, string(qType), expiresAt)
if err != nil {
return contextutils.WrapError(err, "failed to upsert generation hint")
}
return nil
}
// GetActiveHintsForUser returns userID's hints whose expires_at is still in
// the future (per the database clock), ordered oldest-created first.
//
// It returns a nil slice when there are no active hints. Query, scan and
// iteration failures are wrapped and returned.
func (s *GenerationHintService) GetActiveHintsForUser(ctx context.Context, userID int) (result0 []GenerationHint, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_active_generation_hints")
	defer observability.FinishSpan(span, &err)

	const selectSQL = `
SELECT id, user_id, language, level, question_type, priority_weight, expires_at, created_at
FROM generation_hints
WHERE user_id = $1 AND expires_at > NOW()
ORDER BY created_at ASC
`
	rows, err := s.db.QueryContext(ctx, selectSQL, userID)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to query generation hints")
	}
	defer func() { _ = rows.Close() }()

	var active []GenerationHint
	for rows.Next() {
		var hint GenerationHint
		scanErr := rows.Scan(
			&hint.ID, &hint.UserID, &hint.Language, &hint.Level,
			&hint.QuestionType, &hint.PriorityWeight, &hint.ExpiresAt, &hint.CreatedAt,
		)
		if scanErr != nil {
			return nil, contextutils.WrapError(scanErr, "failed to scan generation hint")
		}
		active = append(active, hint)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, contextutils.WrapError(iterErr, "error iterating generation hints")
	}
	return active, nil
}
// ClearHint deletes a specific hint
func (s *GenerationHintService) ClearHint(ctx context.Context, userID int, language, level string, qType models.QuestionType) (err error) {
ctx, span := observability.TraceWorkerFunction(ctx, "clear_generation_hint")
defer observability.FinishSpan(span, &err)
_, err = s.db.ExecContext(ctx, `
DELETE FROM generation_hints
WHERE user_id = $1 AND language = $2 AND level = $3 AND question_type = $4
`, userID, language, level, string(qType))
if err != nil {
return contextutils.WrapError(err, "failed to clear generation hint")
}
return nil
}
package services
import (
"context"
"database/sql"
"fmt"
"math"
"strings"
"time"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
"github.com/lib/pq"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)
// LearningServiceInterface defines the interface for the learning service:
// recording answers, computing progress and priority scores, managing learning
// preferences, and exposing analytics.
// NOTE(review): presumably implemented by *LearningService; no compile-time
// assertion is visible in this file.
type LearningServiceInterface interface {
	// RecordUserResponse persists one answer and updates the user's per-topic
	// performance metrics.
	RecordUserResponse(ctx context.Context, response *models.UserResponse) error
	// GetUserProgress summarizes the user's overall and per-topic progress.
	GetUserProgress(ctx context.Context, userID int) (*models.UserProgress, error)
	GetWeakestTopics(ctx context.Context, userID, limit int) ([]*models.PerformanceMetrics, error)
	ShouldAvoidQuestion(ctx context.Context, userID, questionID int) (bool, error)
	GetUserQuestionStats(ctx context.Context, userID int) (*UserQuestionStats, error)
	// Priority system methods
	RecordAnswerWithPriority(ctx context.Context, userID, questionID, answerIndex int, isCorrect bool, responseTime int) error
	// RecordAnswerWithPriorityReturningID records the response and returns the created user_responses.id
	RecordAnswerWithPriorityReturningID(ctx context.Context, userID, questionID, answerIndex int, isCorrect bool, responseTime int) (int, error)
	MarkQuestionAsKnown(ctx context.Context, userID, questionID int, confidenceLevel *int) error
	GetUserLearningPreferences(ctx context.Context, userID int) (*models.UserLearningPreferences, error)
	UpdateLastDailyReminderSent(ctx context.Context, userID int) error
	CalculatePriorityScore(ctx context.Context, userID, questionID int) (float64, error)
	UpdateUserLearningPreferences(ctx context.Context, userID int, prefs *models.UserLearningPreferences) (*models.UserLearningPreferences, error)
	GetUserQuestionConfidenceLevel(ctx context.Context, userID, questionID int) (*int, error)
	// Analytics methods
	GetPriorityScoreDistribution(ctx context.Context) (map[string]interface{}, error)
	GetHighPriorityQuestions(ctx context.Context, limit int) ([]map[string]interface{}, error)
	GetWeakAreasByTopic(ctx context.Context, limit int) ([]map[string]interface{}, error)
	GetLearningPreferencesUsage(ctx context.Context) (map[string]interface{}, error)
	GetQuestionTypeGaps(ctx context.Context) ([]map[string]interface{}, error)
	GetGenerationSuggestions(ctx context.Context) ([]map[string]interface{}, error)
	GetPrioritySystemPerformance(ctx context.Context) (map[string]interface{}, error)
	GetBackgroundJobsStatus(ctx context.Context) (map[string]interface{}, error)
	// User-specific analytics methods
	GetUserPriorityScoreDistribution(ctx context.Context, userID int) (map[string]interface{}, error)
	GetUserHighPriorityQuestions(ctx context.Context, userID, limit int) ([]map[string]interface{}, error)
	GetUserWeakAreas(ctx context.Context, userID, limit int) ([]map[string]interface{}, error)
	// Additional analytics methods for progress API
	GetHighPriorityTopics(ctx context.Context, userID int) ([]string, error)
	GetGapAnalysis(ctx context.Context, userID int) (map[string]interface{}, error)
	GetPriorityDistribution(ctx context.Context, userID int) (map[string]int, error)
}
// UserQuestionStats represents per-user question statistics, serialized to
// JSON with snake_case keys.
// NOTE(review): whether the accuracy fields are fractions (0-1) or
// percentages (0-100) is decided by the producer, which is not visible
// here. Confirm before displaying them.
type UserQuestionStats struct {
	UserID           int                `json:"user_id"`
	TotalAnswered    int                `json:"total_answered"`
	CorrectAnswers   int                `json:"correct_answers"`
	IncorrectAnswers int                `json:"incorrect_answers"`
	AccuracyRate     float64            `json:"accuracy_rate"`
	AnsweredByType   map[string]int     `json:"answered_by_type"`
	AnsweredByLevel  map[string]int     `json:"answered_by_level"`
	AccuracyByType   map[string]float64 `json:"accuracy_by_type"`
	AccuracyByLevel  map[string]float64 `json:"accuracy_by_level"`
	AvailableByType  map[string]int     `json:"available_by_type"`
	AvailableByLevel map[string]int     `json:"available_by_level"`
	RecentlyAnswered int                `json:"recently_answered"` // Within last hour
}
// Note: ErrQuestionNotFound (returned when a question does not exist in the
// database) is defined in the contextutils package, not in this file.
// LearningService provides methods for managing user learning progress:
// recording responses, maintaining performance_metrics, and summarizing
// progress. Construct it with NewLearningServiceWithLogger.
type LearningService struct {
	db     *sql.DB
	cfg    *config.Config
	logger *observability.Logger
}
// NewLearningServiceWithLogger returns a LearningService that queries db,
// reads settings from cfg, and reports through logger.
func NewLearningServiceWithLogger(db *sql.DB, cfg *config.Config, logger *observability.Logger) *LearningService {
	svc := &LearningService{db: db}
	svc.cfg, svc.logger = cfg, logger
	return svc
}
// RecordUserResponse inserts response into user_responses and then folds it
// into the user's performance metrics.
//
// On a successful insert, response.ID is set to the new row's id. The insert
// and the metrics update are separate statements, not one transaction, so a
// metrics failure leaves the response row in place. Errors are returned
// unwrapped.
func (s *LearningService) RecordUserResponse(ctx context.Context, response *models.UserResponse) (err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "record_user_response",
		observability.AttributeUserID(response.UserID),
		observability.AttributeQuestionID(response.QuestionID),
		attribute.Bool("response.is_correct", response.IsCorrect),
		attribute.Int("response.time_ms", response.ResponseTimeMs),
	)
	defer observability.FinishSpan(span, &err)

	const insertSQL = `
INSERT INTO user_responses (user_id, question_id, user_answer_index, is_correct, response_time_ms)
VALUES ($1, $2, $3, $4, $5) RETURNING id
`
	row := s.db.QueryRowContext(ctx, insertSQL,
		response.UserID,
		response.QuestionID,
		response.UserAnswerIndex,
		response.IsCorrect,
		response.ResponseTimeMs,
	)

	// Scan into a local so response.ID is only touched once the insert succeeded.
	var newID int
	if err = row.Scan(&newID); err != nil {
		return err
	}
	response.ID = newID

	return s.updatePerformanceMetrics(ctx, response)
}
func (s *LearningService) updatePerformanceMetrics(ctx context.Context, response *models.UserResponse) (err error) {
ctx, span := observability.TraceLearningFunction(ctx, "update_performance_metrics",
observability.AttributeUserID(response.UserID),
observability.AttributeQuestionID(response.QuestionID),
attribute.Bool("response.is_correct", response.IsCorrect),
)
defer observability.FinishSpan(span, &err)
// Get question details
var question *models.Question
question, err = s.getQuestionDetails(ctx, response.QuestionID)
if err != nil {
return err
}
// Update or create performance metrics
query := `
INSERT INTO performance_metrics (
user_id, topic, language, level, total_attempts, correct_attempts,
average_response_time_ms, difficulty_adjustment, last_updated
)
VALUES ($1, $2, $3, $4, 1, $5, $6, 0.0, CURRENT_TIMESTAMP)
ON CONFLICT(user_id, topic, language, level) DO UPDATE SET
total_attempts = performance_metrics.total_attempts + 1,
correct_attempts = performance_metrics.correct_attempts + $7,
average_response_time_ms = (performance_metrics.average_response_time_ms * (performance_metrics.total_attempts - 1) + $8) / performance_metrics.total_attempts,
last_updated = CURRENT_TIMESTAMP
`
correctIncrement := 0
if response.IsCorrect {
correctIncrement = 1
}
_, err = s.db.ExecContext(ctx, query,
response.UserID,
question.TopicCategory,
question.Language,
question.Level,
correctIncrement, // For initial correct_attempts in VALUES
float64(response.ResponseTimeMs), // For initial average_response_time_ms in VALUES
correctIncrement, // For correct_attempts increment in UPDATE
response.ResponseTimeMs, // For average_response_time_ms calculation in UPDATE
)
return err
}
// getUserByID is a lightweight helper for LearningService to fetch a user row.
//
// It returns (nil, nil) when no user has the given id, and (nil, err) for any
// other query or scan failure.
func (s *LearningService) getUserByID(ctx context.Context, userID int) (*models.User, error) {
	const selectSQL = `
SELECT id, username, email, timezone, password_hash, last_active,
preferred_language, current_level, ai_provider, ai_model,
ai_enabled, ai_api_key, created_at, updated_at
FROM users
WHERE id = $1
`
	u := new(models.User)
	switch err := s.db.QueryRowContext(ctx, selectSQL, userID).Scan(
		&u.ID, &u.Username, &u.Email, &u.Timezone, &u.PasswordHash, &u.LastActive,
		&u.PreferredLanguage, &u.CurrentLevel, &u.AIProvider, &u.AIModel,
		&u.AIEnabled, &u.AIAPIKey, &u.CreatedAt, &u.UpdatedAt,
	); {
	case err == sql.ErrNoRows:
		return nil, nil
	case err != nil:
		return nil, err
	}
	return u, nil
}
// getQuestionDetails loads the fields performance tracking needs (type,
// language, level, topic category) for questionID.
//
// A NULL topic_category yields an empty TopicCategory. On any error,
// including sql.ErrNoRows when the question does not exist, it returns a nil
// question together with the unwrapped error, so callers can still compare
// against sql.ErrNoRows.
func (s *LearningService) getQuestionDetails(ctx context.Context, questionID int) (result0 *models.Question, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_question_details",
		observability.AttributeQuestionID(questionID),
	)
	defer observability.FinishSpan(span, &err)

	query := `SELECT type, language, level, topic_category FROM questions WHERE id = $1`
	question := &models.Question{}
	var topicCategory sql.NullString
	err = s.db.QueryRowContext(ctx, query, questionID).Scan(
		&question.Type,
		&question.Language,
		&question.Level,
		&topicCategory,
	)
	if err != nil {
		// Don't hand back a half-populated question alongside the error.
		return nil, err
	}

	if topicCategory.Valid {
		question.TopicCategory = topicCategory.String
	}
	return question, nil
}
// GetUserProgress retrieves comprehensive learning progress for userID:
//   - overall answer totals and accuracy (percentage),
//   - per-topic performance metrics keyed by "topic_language_level",
//   - weak areas,
//   - the 10 most recent responses,
//   - the user's current level, and
//   - a suggested level adjustment.
//
// It returns an error if any underlying query fails, including when the user
// row does not exist (surfaced by getCurrentUserLevel).
func (s *LearningService) GetUserProgress(ctx context.Context, userID int) (result0 *models.UserProgress, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_user_progress",
		observability.AttributeUserID(userID),
	)
	defer observability.FinishSpan(span, &err)

	progress := &models.UserProgress{
		PerformanceByTopic: make(map[string]*models.PerformanceMetrics),
	}

	// Get overall stats
	overallQuery := `
SELECT
COUNT(*) as total,
COALESCE(SUM(CASE WHEN is_correct THEN 1 ELSE 0 END), 0) as correct
FROM user_responses
WHERE user_id = $1
`
	err = s.db.QueryRowContext(ctx, overallQuery, userID).Scan(
		&progress.TotalQuestions,
		&progress.CorrectAnswers,
	)
	if err != nil && err != sql.ErrNoRows {
		return nil, err
	}

	if progress.TotalQuestions > 0 {
		progress.AccuracyRate = float64(progress.CorrectAnswers) / float64(progress.TotalQuestions) * 100
	}

	// Load per-topic metrics in a helper so that result set is closed before
	// the follow-up queries run, instead of pinning a pooled connection for
	// the rest of this function.
	if err = s.loadPerformanceByTopic(ctx, userID, progress.PerformanceByTopic); err != nil {
		return nil, err
	}

	// Identify weak areas (accuracy < 60%)
	progress.WeakAreas = s.identifyWeakAreas(progress.PerformanceByTopic)

	progress.RecentActivity, err = s.getRecentActivity(ctx, userID, 10)
	if err != nil {
		return nil, err
	}

	progress.CurrentLevel, err = s.getCurrentUserLevel(ctx, userID)
	if err != nil {
		return nil, err
	}

	// Suggest level adjustment if needed
	progress.SuggestedLevel = s.suggestLevelAdjustment(progress)

	return progress, nil
}

// loadPerformanceByTopic scans every performance_metrics row for userID into
// dst, keyed by "topic_language_level", and closes the result set before
// returning. It reports query, scan, and row-iteration errors.
func (s *LearningService) loadPerformanceByTopic(ctx context.Context, userID int, dst map[string]*models.PerformanceMetrics) error {
	metricsQuery := `
SELECT id, topic, language, level, total_attempts, correct_attempts,
average_response_time_ms, difficulty_adjustment, last_updated
FROM performance_metrics
WHERE user_id = $1
`
	rows, err := s.db.QueryContext(ctx, metricsQuery, userID)
	if err != nil {
		return err
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	for rows.Next() {
		metric := &models.PerformanceMetrics{UserID: userID}
		if err := rows.Scan(
			&metric.ID,
			&metric.Topic,
			&metric.Language,
			&metric.Level,
			&metric.TotalAttempts,
			&metric.CorrectAttempts,
			&metric.AverageResponseTimeMs,
			&metric.DifficultyAdjustment,
			&metric.LastUpdated,
		); err != nil {
			return err
		}

		key := metric.Topic + "_" + metric.Language + "_" + metric.Level
		dst[key] = metric
	}

	// rows.Next returns false both when rows are exhausted and when iteration
	// fails; only Err distinguishes the two.
	return rows.Err()
}
// identifyWeakAreas returns the keys of metrics whose topic has at least 3
// attempts and an accuracy below 60%. The order of the result follows map
// iteration and is therefore unspecified.
//
// This is pure in-memory analysis with no external calls, so it is not traced.
func (s *LearningService) identifyWeakAreas(metrics map[string]*models.PerformanceMetrics) []string {
	const (
		minAttempts     = 3
		weakAccuracyPct = 60.0
	)

	var weak []string
	for key, m := range metrics {
		if m.TotalAttempts < minAttempts {
			continue // too little data to judge
		}
		if m.AccuracyRate() >= weakAccuracyPct {
			continue
		}
		weak = append(weak, key)
	}
	return weak
}
// getRecentActivity returns up to limit of userID's most recent rows from
// user_responses, newest first (ORDER BY created_at DESC).
//
// Any query, scan, or row-iteration error is returned with a nil slice.
// A user with no responses yields a nil slice and a nil error.
func (s *LearningService) getRecentActivity(ctx context.Context, userID, limit int) (result0 []models.UserResponse, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_recent_activity",
		observability.AttributeUserID(userID),
		attribute.Int("limit", limit),
	)
	defer observability.FinishSpan(span, &err)
	query := `
SELECT id, user_id, question_id, user_answer_index, is_correct, response_time_ms, created_at
FROM user_responses
WHERE user_id = $1
ORDER BY created_at DESC
LIMIT $2
`
	rows, err := s.db.QueryContext(ctx, query, userID, limit)
	if err != nil {
		return nil, err
	}
	defer func() {
		if err := rows.Close(); err != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": err.Error()})
		}
	}()
	var responses []models.UserResponse
	for rows.Next() {
		var response models.UserResponse
		err = rows.Scan(
			&response.ID,
			&response.UserID,
			&response.QuestionID,
			&response.UserAnswerIndex,
			&response.IsCorrect,
			&response.ResponseTimeMs,
			&response.CreatedAt,
		)
		if err != nil {
			return nil, err
		}
		responses = append(responses, response)
	}
	// rows.Next returning false can mean "done" or "failed"; without this
	// check a mid-stream error would be reported as a short, successful result.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return responses, nil
}
// getCurrentUserLevel reads users.current_level for userID.
//
// A NULL or empty stored level yields the default "A1". Any query error,
// including sql.ErrNoRows when no such user exists, is returned as-is.
func (s *LearningService) getCurrentUserLevel(ctx context.Context, userID int) (result0 string, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_current_user_level",
		observability.AttributeUserID(userID),
	)
	defer observability.FinishSpan(span, &err)

	var stored sql.NullString
	if err = s.db.QueryRowContext(ctx, `SELECT current_level FROM users WHERE id = $1`, userID).Scan(&stored); err != nil {
		return "", err
	}
	if stored.Valid && stored.String != "" {
		return stored.String, nil
	}
	// Nothing usable stored: fall back to the default level.
	return "A1", nil
}
// suggestLevelAdjustment proposes a new CEFR-style level based on overall
// accuracy. It returns "" when no change is suggested.
//
//   - fewer than 20 answered questions: "" (not enough data)
//   - accuracy above 85%: the level after CurrentLevel
//   - accuracy below 50%: the level before CurrentLevel
//   - otherwise: ""
func (s *LearningService) suggestLevelAdjustment(progress *models.UserProgress) string {
	// Pure decision logic over an in-memory value; no tracing needed.
	switch {
	case progress.TotalQuestions < 20:
		return ""
	case progress.AccuracyRate > 85.0:
		return s.getNextLevel(progress.CurrentLevel)
	case progress.AccuracyRate < 50.0:
		return s.getPreviousLevel(progress.CurrentLevel)
	default:
		return ""
	}
}
// getNextLevel returns the level that follows currentLevel in the order given
// by s.cfg.GetAllLevels(). If currentLevel is the last level or is not found,
// currentLevel itself is returned.
func (s *LearningService) getNextLevel(currentLevel string) string {
	all := s.cfg.GetAllLevels()
	// Stop one short of the end: the final level has no successor.
	for idx := 0; idx+1 < len(all); idx++ {
		if all[idx] == currentLevel {
			return all[idx+1]
		}
	}
	return currentLevel
}
// getPreviousLevel returns the level that precedes currentLevel in the order
// given by s.cfg.GetAllLevels(). If currentLevel is the first level or is not
// found, currentLevel itself is returned.
func (s *LearningService) getPreviousLevel(currentLevel string) string {
	all := s.cfg.GetAllLevels()
	// Start at 1: the first level has no predecessor.
	for idx := 1; idx < len(all); idx++ {
		if all[idx] == currentLevel {
			return all[idx-1]
		}
	}
	return currentLevel
}
// GetWeakestTopics returns up to limit performance_metrics rows for userID,
// ordered from lowest to highest accuracy (correct_attempts / total_attempts),
// with ties broken by the oldest last_updated first. Only rows with at least
// three attempts are considered, which also guarantees the SQL division is
// never by zero.
//
// Any query, scan, or row-iteration error is returned with a nil slice.
func (s *LearningService) GetWeakestTopics(ctx context.Context, userID, limit int) (result0 []*models.PerformanceMetrics, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_weakest_topics",
		observability.AttributeUserID(userID),
		attribute.Int("limit", limit),
	)
	defer observability.FinishSpan(span, &err)
	query := `
SELECT id, topic, language, level, total_attempts, correct_attempts, average_response_time_ms, difficulty_adjustment, last_updated
FROM performance_metrics
WHERE user_id = $1 AND total_attempts >= 3
ORDER BY (correct_attempts * 1.0 / total_attempts) ASC, last_updated ASC
LIMIT $2
`
	rows, err := s.db.QueryContext(ctx, query, userID, limit)
	if err != nil {
		return nil, err
	}
	defer func() {
		if err := rows.Close(); err != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": err.Error()})
		}
	}()
	var topics []*models.PerformanceMetrics
	for rows.Next() {
		metric := &models.PerformanceMetrics{UserID: userID}
		err = rows.Scan(
			&metric.ID,
			&metric.Topic,
			&metric.Language,
			&metric.Level,
			&metric.TotalAttempts,
			&metric.CorrectAttempts,
			&metric.AverageResponseTimeMs,
			&metric.DifficultyAdjustment,
			&metric.LastUpdated,
		)
		if err != nil {
			return nil, err
		}
		topics = append(topics, metric)
	}
	// Surface errors that terminated iteration early instead of returning a
	// silently truncated list.
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return topics, nil
}
// ShouldAvoidQuestion reports whether userID already has a correct response
// to questionID inside a one-day window computed by
// contextutils.UserLocalDayRange (presumably the user's current local day,
// converted to UTC; confirm against that helper).
//
// The result is also recorded on the span as "should_avoid". If the count
// query fails, false is returned together with the query error.
func (s *LearningService) ShouldAvoidQuestion(ctx context.Context, userID, questionID int) (result0 bool, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "should_avoid_question",
		observability.AttributeUserID(userID),
		observability.AttributeQuestionID(questionID),
	)
	defer observability.FinishSpan(span, &err)

	windowStart, windowEnd, _, err := contextutils.UserLocalDayRange(ctx, userID, 1, s.getUserByID)
	if err != nil {
		return false, contextutils.WrapError(err, "failed to compute user local day range")
	}

	const correctInWindow = `
SELECT COUNT(*)
FROM user_responses
WHERE user_id = $1 AND question_id = $2 AND is_correct = true
AND created_at >= $3 AND created_at < $4
`
	var hits int
	err = s.db.QueryRowContext(ctx, correctInWindow, userID, questionID, windowStart, windowEnd).Scan(&hits)

	avoid := hits > 0
	span.SetAttributes(attribute.Bool("should_avoid", avoid))
	return avoid, err
}
// GetUserQuestionStats returns comprehensive per-user question statistics.
//
// The user's preferred language and current level are read from users (NULLs
// fall back to 'italian' and 'B1'). Any error there, including a missing user,
// is returned. The remaining sections are:
//   - answered questions grouped by question type and by level, with accuracy
//     percentages derived from the same grouped rows;
//   - active questions assigned to the user via user_questions, in the user's
//     preferred language, that the user has not answered yet;
//   - responses recorded within the last hour (best effort: 0 on failure);
//   - overall correct/incorrect counts and accuracy (best effort: 0 on failure).
//
// Errors from the answered/available queries, including row-iteration errors,
// are returned with a nil stats value.
func (s *LearningService) GetUserQuestionStats(ctx context.Context, userID int) (result0 *UserQuestionStats, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_user_question_stats",
		observability.AttributeUserID(userID),
	)
	defer observability.FinishSpan(span, &err)
	stats := &UserQuestionStats{
		UserID:           userID,
		AnsweredByType:   make(map[string]int),
		AnsweredByLevel:  make(map[string]int),
		AccuracyByType:   make(map[string]float64),
		AccuracyByLevel:  make(map[string]float64),
		AvailableByType:  make(map[string]int),
		AvailableByLevel: make(map[string]int),
	}

	// Get user's language and level preferences
	var userLanguage, userLevel string
	userQuery := `SELECT COALESCE(preferred_language, 'italian'), COALESCE(current_level, 'B1') FROM users WHERE id = $1`
	if err = s.db.QueryRowContext(ctx, userQuery, userID).Scan(&userLanguage, &userLevel); err != nil {
		return nil, err
	}
	span.SetAttributes(
		attribute.String("user.language", userLanguage),
		attribute.String("user.level", userLevel),
	)

	// Answered questions grouped by (type, level). Per-type and per-level
	// accuracy is accumulated from these same rows: summing the groups gives
	// exactly the totals that separate per-type / per-level queries would,
	// without issuing two extra queries per group while this result set is
	// still open.
	answeredQuery := `
SELECT
q.type,
q.level,
COUNT(*) as total,
SUM(CASE WHEN ur.is_correct THEN 1 ELSE 0 END) as correct
FROM user_responses ur
JOIN questions q ON ur.question_id = q.id
WHERE ur.user_id = $1
GROUP BY q.type, q.level
`
	correctByType := make(map[string]int)
	correctByLevel := make(map[string]int)
	// Each result set lives in its own closure so its deferred Close applies
	// to exactly that *sql.Rows (a shared, reassigned variable would not).
	err = func() error {
		rows, qErr := s.db.QueryContext(ctx, answeredQuery, userID)
		if qErr != nil {
			return qErr
		}
		defer func() {
			if cErr := rows.Close(); cErr != nil {
				s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": cErr.Error()})
			}
		}()
		for rows.Next() {
			var qType, level string
			var total, correct int
			if sErr := rows.Scan(&qType, &level, &total, &correct); sErr != nil {
				return sErr
			}
			stats.AnsweredByType[qType] += total
			stats.AnsweredByLevel[level] += total
			stats.TotalAnswered += total
			correctByType[qType] += correct
			correctByLevel[level] += correct
		}
		return rows.Err()
	}()
	if err != nil {
		return nil, err
	}
	for qType, total := range stats.AnsweredByType {
		if total > 0 {
			stats.AccuracyByType[qType] = float64(correctByType[qType]) / float64(total) * 100
		}
	}
	for level, total := range stats.AnsweredByLevel {
		if total > 0 {
			stats.AccuracyByLevel[level] = float64(correctByLevel[level]) / float64(total) * 100
		}
	}

	// Get available questions (not answered by user) that belong to this user
	availableQuery := `
SELECT
q.type,
q.level,
COUNT(*) as available
FROM questions q
JOIN user_questions uq ON uq.question_id = q.id
WHERE uq.user_id = $1
AND q.language = $2
AND q.status = 'active'
AND q.id NOT IN (
SELECT DISTINCT question_id
FROM user_responses
WHERE user_id = $3
)
GROUP BY q.type, q.level
`
	err = func() error {
		rows, qErr := s.db.QueryContext(ctx, availableQuery, userID, userLanguage, userID)
		if qErr != nil {
			return qErr
		}
		defer func() {
			if cErr := rows.Close(); cErr != nil {
				s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": cErr.Error()})
			}
		}()
		for rows.Next() {
			var qType, level string
			var available int
			if sErr := rows.Scan(&qType, &level, &available); sErr != nil {
				return sErr
			}
			stats.AvailableByType[qType] += available
			stats.AvailableByLevel[level] += available
		}
		return rows.Err()
	}()
	if err != nil {
		return nil, err
	}

	// Get recently answered questions (within last hour). Best effort: a
	// failure is logged and reported as zero rather than failing the call.
	recentQuery := `
SELECT COUNT(*)
FROM user_responses ur
WHERE ur.user_id = $1
AND ur.created_at > NOW() - INTERVAL '1 hour'
`
	if rErr := s.db.QueryRowContext(ctx, recentQuery, userID).Scan(&stats.RecentlyAnswered); rErr != nil {
		stats.RecentlyAnswered = 0
		s.logger.Warn(ctx, "Failed to count recently answered questions", map[string]interface{}{"error": rErr.Error()})
	}

	// Overall correct/incorrect answers and accuracy rate. COALESCE keeps SUM
	// from returning NULL (and failing the int scan) for users with no
	// responses. Best effort: on failure these fields stay at zero.
	overallQuery := `
SELECT
COUNT(*) as total,
COALESCE(SUM(CASE WHEN is_correct THEN 1 ELSE 0 END), 0) as correct
FROM user_responses
WHERE user_id = $1
`
	var total, correct int
	if oErr := s.db.QueryRowContext(ctx, overallQuery, userID).Scan(&total, &correct); oErr != nil {
		s.logger.Warn(ctx, "Failed to compute overall answer accuracy", map[string]interface{}{"error": oErr.Error()})
	} else {
		stats.CorrectAnswers = correct
		stats.IncorrectAnswers = total - correct
		if total > 0 {
			stats.AccuracyRate = float64(correct) / float64(total) * 100
		}
	}
	return stats, nil
}
// PRIORITY SYSTEM METHODS

// RecordAnswerWithPriority stores a user's answer via RecordUserResponse
// (stamped with the current time) and then recomputes the question's priority
// score in a background goroutine. It behaves exactly like
// RecordAnswerWithPriorityReturningID with the new response ID discarded.
// Persistence errors are wrapped with "failed to record user response".
func (s *LearningService) RecordAnswerWithPriority(ctx context.Context, userID, questionID, answerIndex int, isCorrect bool, responseTime int) error {
	_, err := s.RecordAnswerWithPriorityReturningID(ctx, userID, questionID, answerIndex, isCorrect, responseTime)
	return err
}
// RecordAnswerWithPriorityReturningID stores a user's answer via
// RecordUserResponse (stamped with the current time), starts a background
// priority-score recomputation, and returns response.ID as left by
// RecordUserResponse (presumably the new user_responses primary key; confirm
// that RecordUserResponse populates it).
//
// On failure it returns 0 and the error wrapped with
// "failed to record user response"; no background work is started.
func (s *LearningService) RecordAnswerWithPriorityReturningID(ctx context.Context, userID, questionID, answerIndex int, isCorrect bool, responseTime int) (int, error) {
	answer := models.UserResponse{
		CreatedAt:       time.Now(),
		UserID:          userID,
		QuestionID:      questionID,
		UserAnswerIndex: answerIndex,
		IsCorrect:       isCorrect,
		ResponseTimeMs:  responseTime,
	}
	if err := s.RecordUserResponse(ctx, &answer); err != nil {
		return 0, contextutils.WrapError(err, "failed to record user response")
	}

	// Fire-and-forget: the score refresh must not delay the caller.
	go s.updatePriorityScoreAsync(ctx, userID, questionID)

	return answer.ID, nil
}
// MarkQuestionAsKnown upserts a user_question_metadata row that flags
// questionID as known for userID, stamping marked_as_known_at/updated_at with
// NOW() and storing confidenceLevel (nil stores NULL).
//
// A foreign-key violation (for example, an unknown question) is reported as
// contextutils.ErrQuestionNotFound; any other database error is returned
// unchanged. On success the question's priority score is refreshed in a
// background goroutine so the new confidence affects selection.
func (s *LearningService) MarkQuestionAsKnown(ctx context.Context, userID, questionID int, confidenceLevel *int) (err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "mark_question_as_known",
		observability.AttributeUserID(userID),
		observability.AttributeQuestionID(questionID),
	)
	defer observability.FinishSpan(span, &err)

	// Fresh map per log call: the error path adds keys to its own copy.
	logFields := func() map[string]interface{} {
		return map[string]interface{}{
			"user_id":     userID,
			"question_id": questionID,
		}
	}

	s.logger.Debug(ctx, "MarkQuestionAsKnown called", logFields())

	const upsertKnown = `
INSERT INTO user_question_metadata (user_id, question_id, marked_as_known, marked_as_known_at, confidence_level, created_at, updated_at)
VALUES ($1, $2, TRUE, NOW(), $3, NOW(), NOW())
ON CONFLICT (user_id, question_id) DO UPDATE
SET marked_as_known = TRUE, marked_as_known_at = NOW(), confidence_level = $3, updated_at = NOW()
`
	if _, err = s.db.ExecContext(ctx, upsertKnown, userID, questionID, confidenceLevel); err != nil {
		failure := logFields()
		failure["error"] = err.Error()
		failure["error_type"] = fmt.Sprintf("%T", err)
		s.logger.Debug(ctx, "MarkQuestionAsKnown error", failure)

		if !isForeignKeyConstraintViolation(err) {
			s.logger.Debug(ctx, "Not a foreign key constraint violation, returning original error", logFields())
			return err
		}
		s.logger.Debug(ctx, "Foreign key constraint violation detected", logFields())
		return contextutils.ErrQuestionNotFound
	}

	s.logger.Debug(ctx, "MarkQuestionAsKnown succeeded", logFields())

	go s.updatePriorityScoreAsync(ctx, userID, questionID)
	return nil
}
// GetUserLearningPreferences loads the user_learning_preferences row for
// userID.
//
// When no row exists it first checks that the user exists. A missing user
// yields a wrapped contextutils.ErrRecordNotFound. An existing user gets a
// default preferences row created via createDefaultPreferences, and that row
// is returned. Other database errors are wrapped.
func (s *LearningService) GetUserLearningPreferences(ctx context.Context, userID int) (result0 *models.UserLearningPreferences, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_user_learning_preferences",
		observability.AttributeUserID(userID),
	)
	defer observability.FinishSpan(span, &err)

	var p models.UserLearningPreferences
	row := s.db.QueryRowContext(ctx, `
SELECT id, user_id, focus_on_weak_areas, include_review_questions, fresh_question_ratio,
known_question_penalty, review_interval_days, weak_area_boost, daily_reminder_enabled,
tts_voice, last_daily_reminder_sent, daily_goal, created_at, updated_at
FROM user_learning_preferences
WHERE user_id = $1
`, userID)
	err = row.Scan(
		&p.ID,
		&p.UserID,
		&p.FocusOnWeakAreas,
		&p.IncludeReviewQuestions,
		&p.FreshQuestionRatio,
		&p.KnownQuestionPenalty,
		&p.ReviewIntervalDays,
		&p.WeakAreaBoost,
		&p.DailyReminderEnabled,
		&p.TTSVoice,
		&p.LastDailyReminderSent,
		&p.DailyGoal,
		&p.CreatedAt,
		&p.UpdatedAt,
	)

	switch {
	case err == sql.ErrNoRows:
		// Only create defaults for users that actually exist.
		var userExists bool
		if err = s.db.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM users WHERE id = $1)", userID).Scan(&userExists); err != nil {
			return nil, contextutils.WrapError(err, "failed to check if user exists")
		}
		if !userExists {
			return nil, contextutils.WrapErrorf(contextutils.ErrRecordNotFound, "user %d not found", userID)
		}
		return s.createDefaultPreferences(ctx, userID)
	case err != nil:
		return nil, contextutils.WrapError(err, "failed to get user preferences")
	default:
		return &p, nil
	}
}
// UpdateLastDailyReminderSent sets last_daily_reminder_sent and updated_at to
// NOW() for userID. The statement is an upsert, so a preferences row is
// inserted (with only these columns supplied) when none exists yet.
// Database errors are wrapped.
func (s *LearningService) UpdateLastDailyReminderSent(ctx context.Context, userID int) (err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "update_last_daily_reminder_sent",
		observability.AttributeUserID(userID),
	)
	defer observability.FinishSpan(span, &err)

	const upsertReminder = `
INSERT INTO user_learning_preferences (user_id, last_daily_reminder_sent, updated_at)
VALUES ($1, NOW(), NOW())
ON CONFLICT (user_id) DO UPDATE SET
last_daily_reminder_sent = NOW(),
updated_at = NOW()
`
	if _, err = s.db.ExecContext(ctx, upsertReminder, userID); err != nil {
		return contextutils.WrapError(err, "failed to update last daily reminder sent")
	}
	return nil
}
// UpdateUserLearningPreferences overwrites userID's stored learning
// preferences with the values in prefs and returns the row as persisted.
//
// A DailyGoal of 0 keeps the stored daily goal (NULLIF/COALESCE in the
// UPDATE). If the user has no preferences row yet, the call is delegated to
// createPreferencesWithValues. Other database errors are wrapped.
// prefs must be non-nil: its fields are read to annotate the trace span.
func (s *LearningService) UpdateUserLearningPreferences(ctx context.Context, userID int, prefs *models.UserLearningPreferences) (result0 *models.UserLearningPreferences, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "update_user_learning_preferences",
		observability.AttributeUserID(userID),
		attribute.Bool("prefs.focus_on_weak_areas", prefs.FocusOnWeakAreas),
		attribute.Bool("prefs.include_review_questions", prefs.IncludeReviewQuestions),
		attribute.Float64("prefs.fresh_question_ratio", prefs.FreshQuestionRatio),
		attribute.Float64("prefs.known_question_penalty", prefs.KnownQuestionPenalty),
		attribute.Int("prefs.review_interval_days", prefs.ReviewIntervalDays),
		attribute.Float64("prefs.weak_area_boost", prefs.WeakAreaBoost),
	)
	defer func() {
		if err == nil {
			span.End()
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
		span.End()
	}()

	const updateSQL = `
UPDATE user_learning_preferences
SET focus_on_weak_areas = $2, include_review_questions = $3, fresh_question_ratio = $4,
known_question_penalty = $5, review_interval_days = $6, weak_area_boost = $7,
daily_reminder_enabled = $8, tts_voice = $9, daily_goal = COALESCE(NULLIF($10, 0), daily_goal), updated_at = NOW()
WHERE user_id = $1
RETURNING id, user_id, focus_on_weak_areas, include_review_questions, fresh_question_ratio,
known_question_penalty, review_interval_days, weak_area_boost, daily_reminder_enabled,
tts_voice, last_daily_reminder_sent, daily_goal, created_at, updated_at
`
	var saved models.UserLearningPreferences
	row := s.db.QueryRowContext(ctx, updateSQL,
		userID,
		prefs.FocusOnWeakAreas,
		prefs.IncludeReviewQuestions,
		prefs.FreshQuestionRatio,
		prefs.KnownQuestionPenalty,
		prefs.ReviewIntervalDays,
		prefs.WeakAreaBoost,
		prefs.DailyReminderEnabled,
		prefs.TTSVoice,
		prefs.DailyGoal,
	)
	err = row.Scan(
		&saved.ID,
		&saved.UserID,
		&saved.FocusOnWeakAreas,
		&saved.IncludeReviewQuestions,
		&saved.FreshQuestionRatio,
		&saved.KnownQuestionPenalty,
		&saved.ReviewIntervalDays,
		&saved.WeakAreaBoost,
		&saved.DailyReminderEnabled,
		&saved.TTSVoice,
		&saved.LastDailyReminderSent,
		&saved.DailyGoal,
		&saved.CreatedAt,
		&saved.UpdatedAt,
	)

	switch {
	case err == sql.ErrNoRows:
		// No row to update: create one from the supplied values instead.
		return s.createPreferencesWithValues(ctx, userID, prefs)
	case err != nil:
		return nil, contextutils.WrapError(err, "failed to update user preferences")
	}
	return &saved, nil
}
// createPreferencesWithValues inserts a user_learning_preferences row for
// userID populated from prefs and returns the row as stored.
//
// Numeric fields left at zero (FreshQuestionRatio, KnownQuestionPenalty,
// ReviewIntervalDays, WeakAreaBoost, DailyGoal) are filled from
// GetDefaultLearningPreferences. Boolean and string fields are used exactly as
// given, because their zero values cannot be told apart from an explicit
// choice.
//
// The caller's prefs is NOT modified: defaults are merged into a copy, and
// that copy is returned after being refreshed from the database.
//
// The insert uses ON CONFLICT (user_id) DO NOTHING. If a row already exists
// (for example, one created by a concurrent request), that existing row is
// returned and the supplied values are not applied.
func (s *LearningService) createPreferencesWithValues(ctx context.Context, userID int, prefs *models.UserLearningPreferences) (result0 *models.UserLearningPreferences, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "create_preferences_with_values",
		observability.AttributeUserID(userID),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	defaultPrefs := s.GetDefaultLearningPreferences()

	// Work on a copy so the caller's struct is left untouched.
	merged := *prefs
	merged.UserID = userID
	if merged.FreshQuestionRatio == 0 {
		merged.FreshQuestionRatio = defaultPrefs.FreshQuestionRatio
	}
	if merged.KnownQuestionPenalty == 0 {
		merged.KnownQuestionPenalty = defaultPrefs.KnownQuestionPenalty
	}
	if merged.ReviewIntervalDays == 0 {
		merged.ReviewIntervalDays = defaultPrefs.ReviewIntervalDays
	}
	if merged.WeakAreaBoost == 0 {
		merged.WeakAreaBoost = defaultPrefs.WeakAreaBoost
	}
	if merged.DailyGoal == 0 {
		merged.DailyGoal = defaultPrefs.DailyGoal
	}

	// Try to insert with ON CONFLICT DO NOTHING to handle race conditions
	_, err = s.db.ExecContext(ctx, `
INSERT INTO user_learning_preferences (user_id, focus_on_weak_areas, include_review_questions,
fresh_question_ratio, known_question_penalty,
review_interval_days, weak_area_boost, daily_reminder_enabled,
tts_voice, daily_goal, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NOW(), NOW())
ON CONFLICT (user_id) DO NOTHING
`, userID, merged.FocusOnWeakAreas, merged.IncludeReviewQuestions,
		merged.FreshQuestionRatio, merged.KnownQuestionPenalty,
		merged.ReviewIntervalDays, merged.WeakAreaBoost, merged.DailyReminderEnabled, merged.TTSVoice, merged.DailyGoal)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to create preferences with values")
	}

	// Now fetch the preferences (either the ones we just created or the ones created by another concurrent request)
	err = s.db.QueryRowContext(ctx, `
SELECT id, user_id, focus_on_weak_areas, include_review_questions, fresh_question_ratio,
known_question_penalty, review_interval_days, weak_area_boost, daily_reminder_enabled,
tts_voice, last_daily_reminder_sent, daily_goal, created_at, updated_at
FROM user_learning_preferences
WHERE user_id = $1
`, userID).Scan(
		&merged.ID, &merged.UserID, &merged.FocusOnWeakAreas, &merged.IncludeReviewQuestions,
		&merged.FreshQuestionRatio, &merged.KnownQuestionPenalty, &merged.ReviewIntervalDays,
		&merged.WeakAreaBoost, &merged.DailyReminderEnabled, &merged.TTSVoice, &merged.LastDailyReminderSent,
		&merged.DailyGoal, &merged.CreatedAt, &merged.UpdatedAt,
	)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to fetch created preferences")
	}
	return &merged, nil
}
// createDefaultPreferences inserts a user_learning_preferences row for userID
// using GetDefaultLearningPreferences and returns the row as stored.
//
// The insert is ON CONFLICT (user_id) DO NOTHING, so when another request has
// already created the row, that existing row is what gets read back and
// returned. Database errors are wrapped.
func (s *LearningService) createDefaultPreferences(ctx context.Context, userID int) (result0 *models.UserLearningPreferences, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "create_default_preferences",
		observability.AttributeUserID(userID),
	)
	defer func() {
		if err == nil {
			span.End()
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
		span.End()
	}()

	p := s.GetDefaultLearningPreferences()
	p.UserID = userID

	const insertSQL = `
INSERT INTO user_learning_preferences (user_id, focus_on_weak_areas, include_review_questions,
fresh_question_ratio, known_question_penalty,
review_interval_days, weak_area_boost, daily_reminder_enabled,
tts_voice, daily_goal, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NOW(), NOW())
ON CONFLICT (user_id) DO NOTHING
`
	insertArgs := []interface{}{
		userID,
		p.FocusOnWeakAreas,
		p.IncludeReviewQuestions,
		p.FreshQuestionRatio,
		p.KnownQuestionPenalty,
		p.ReviewIntervalDays,
		p.WeakAreaBoost,
		p.DailyReminderEnabled,
		p.TTSVoice,
		p.DailyGoal,
	}
	if _, err = s.db.ExecContext(ctx, insertSQL, insertArgs...); err != nil {
		return nil, contextutils.WrapError(err, "failed to create default preferences")
	}

	// Read back whichever row won: ours, or a concurrently inserted one.
	const selectSQL = `
SELECT id, user_id, focus_on_weak_areas, include_review_questions, fresh_question_ratio,
known_question_penalty, review_interval_days, weak_area_boost, daily_reminder_enabled,
tts_voice, last_daily_reminder_sent, daily_goal, created_at, updated_at
FROM user_learning_preferences
WHERE user_id = $1
`
	scanInto := []interface{}{
		&p.ID,
		&p.UserID,
		&p.FocusOnWeakAreas,
		&p.IncludeReviewQuestions,
		&p.FreshQuestionRatio,
		&p.KnownQuestionPenalty,
		&p.ReviewIntervalDays,
		&p.WeakAreaBoost,
		&p.DailyReminderEnabled,
		&p.TTSVoice,
		&p.LastDailyReminderSent,
		&p.DailyGoal,
		&p.CreatedAt,
		&p.UpdatedAt,
	}
	if err = s.db.QueryRowContext(ctx, selectSQL, userID).Scan(scanInto...); err != nil {
		return nil, contextutils.WrapError(err, "failed to fetch created preferences")
	}
	return p, nil
}
// GetDefaultLearningPreferences returns a freshly allocated preferences value
// holding the built-in defaults. ID, UserID, and the timestamp fields are left
// at their zero values. Each call returns a new value, so callers may mutate
// it freely.
func (s *LearningService) GetDefaultLearningPreferences() *models.UserLearningPreferences {
	defaults := new(models.UserLearningPreferences)

	// Question selection behaviour.
	defaults.FocusOnWeakAreas = true
	defaults.IncludeReviewQuestions = true
	defaults.FreshQuestionRatio = 0.3
	defaults.KnownQuestionPenalty = 0.1
	defaults.ReviewIntervalDays = 7
	defaults.WeakAreaBoost = 2.0

	// Engagement settings: daily reminders are opt-in.
	defaults.DailyReminderEnabled = false
	defaults.DailyGoal = 10
	defaults.TTSVoice = ""

	return defaults
}
// CalculatePriorityScore computes the selection priority of questionID for
// userID as
//
//	100 * performance * spacedRepetition * userPreference * freshness
//
// clamped to the range [1, 1000]. The factors come from the user's learning
// preferences and their answer history and metadata for the question. Lookup
// failures are wrapped with contextutils.ErrDatabaseQuery.
func (s *LearningService) CalculatePriorityScore(ctx context.Context, userID, questionID int) (result0 float64, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "calculate_priority_score",
		observability.AttributeUserID(userID),
		observability.AttributeQuestionID(questionID),
	)
	defer func() {
		if err == nil {
			span.End()
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
		span.End()
	}()

	prefs, err := s.GetUserLearningPreferences(ctx, userID)
	if err != nil {
		return 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get user preferences: %w", err)
	}
	history, err := s.getQuestionPerformance(ctx, userID, questionID)
	if err != nil {
		return 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get question performance: %w", err)
	}

	const (
		startingScore = 100.0
		floorScore    = 1.0
		ceilingScore  = 1000.0
	)
	// Multiply left to right in the same order as the formula above.
	score := startingScore
	score *= s.calculatePerformanceMultiplier(history, prefs.WeakAreaBoost)
	score *= s.calculateSpacedRepetitionBoost(history.LastSeenAt)
	score *= s.calculateUserPreferenceMultiplier(history, prefs)
	score *= s.calculateFreshnessBoost(history.TimesAnswered)

	// Clamp to keep a single factor from dominating selection.
	switch {
	case score < floorScore:
		score = floorScore
	case score > ceilingScore:
		score = ceilingScore
	}
	return score, nil
}
// updatePriorityScoreAsync recomputes the priority score for
// (userID, questionID) and upserts it into question_priority_scores. It is
// launched with `go` by the answer-recording and mark-as-known paths, so
// failures are logged rather than returned.
//
// The ctx passed in may be a request-scoped context that is cancelled as soon
// as the caller returns. That would race with this background work and could
// abort its queries. The context is therefore detached from the caller's
// cancellation and deadline (its values, such as the active trace span, are
// kept) and given its own bounded timeout so the goroutine cannot hang
// indefinitely.
func (s *LearningService) updatePriorityScoreAsync(ctx context.Context, userID, questionID int) {
	const priorityUpdateTimeout = 30 * time.Second
	ctx, cancel := context.WithTimeout(context.WithoutCancel(ctx), priorityUpdateTimeout)
	defer cancel()

	ctx, span := observability.TraceLearningFunction(ctx, "update_priority_score_async",
		observability.AttributeUserID(userID),
		observability.AttributeQuestionID(questionID),
	)
	defer span.End()

	score, err := s.CalculatePriorityScore(ctx, userID, questionID)
	if err != nil {
		s.logger.Error(ctx, "Failed to calculate priority score", err, map[string]interface{}{
			"user_id":     userID,
			"question_id": questionID,
		})
		return
	}
	// Update or insert priority score
	_, err = s.db.ExecContext(ctx, `
INSERT INTO question_priority_scores (user_id, question_id, priority_score, last_calculated_at, created_at, updated_at)
VALUES ($1, $2, $3, NOW(), NOW(), NOW())
ON CONFLICT (user_id, question_id) DO UPDATE
SET priority_score = $3, last_calculated_at = NOW(), updated_at = NOW()
`, userID, questionID, score)
	if err != nil {
		s.logger.Error(ctx, "Failed to update priority score", err, map[string]interface{}{
			"user_id":     userID,
			"question_id": questionID,
			"score":       score,
		})
	}
}
// QuestionPerformance represents performance data for a specific question
// as seen by one user. It is filled by getQuestionPerformance from
// user_responses (answer history) and user_question_metadata (the
// "mark as known" signal), and consumed by the priority-score calculators.
type QuestionPerformance struct {
// TimesAnswered is the number of user_responses rows for this user/question.
TimesAnswered int
// CorrectAnswers is how many of those responses have is_correct = true.
CorrectAnswers int
// LastSeenAt is MAX(created_at) of the responses; nil when never answered.
LastSeenAt *time.Time
// MarkedAsKnown mirrors user_question_metadata.marked_as_known; false when
// no metadata row exists.
MarkedAsKnown bool
// MarkedAsKnownAt is nil when the column is NULL or no metadata row exists.
MarkedAsKnownAt *time.Time
// ConfidenceLevel is nil when the column is NULL or no metadata row exists.
// calculateUserPreferenceMultiplier gives special meaning to values 1-5.
ConfidenceLevel *int
}
// getQuestionPerformance gathers userID's history with questionID.
//
// Two queries run: one aggregating user_responses (count, correct count,
// latest created_at), and one reading user_question_metadata (known flag,
// known-at timestamp, confidence). A missing metadata row is not an error; the
// corresponding fields stay zero or nil. Other query failures are wrapped with
// contextutils.ErrDatabaseQuery.
func (s *LearningService) getQuestionPerformance(ctx context.Context, userID, questionID int) (result0 *QuestionPerformance, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_question_performance",
		observability.AttributeUserID(userID),
		observability.AttributeQuestionID(questionID),
	)
	defer func() {
		if err == nil {
			span.End()
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
		span.End()
	}()

	perf := new(QuestionPerformance)

	const responseStatsSQL = `
SELECT
COUNT(*) as times_answered,
COALESCE(SUM(CASE WHEN is_correct THEN 1 ELSE 0 END), 0) as correct_answers,
MAX(created_at) as last_seen_at
FROM user_responses
WHERE user_id = $1 AND question_id = $2
`
	err = s.db.QueryRowContext(ctx, responseStatsSQL, userID, questionID).
		Scan(&perf.TimesAnswered, &perf.CorrectAnswers, &perf.LastSeenAt)
	if err != nil && err != sql.ErrNoRows {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get response statistics: %w", err)
	}

	const metadataSQL = `
SELECT marked_as_known, marked_as_known_at, confidence_level
FROM user_question_metadata
WHERE user_id = $1 AND question_id = $2
`
	var (
		knownAt    sql.NullTime
		confidence sql.NullInt32
	)
	err = s.db.QueryRowContext(ctx, metadataSQL, userID, questionID).
		Scan(&perf.MarkedAsKnown, &knownAt, &confidence)
	if err != nil && err != sql.ErrNoRows {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get question metadata: %w", err)
	}

	// Translate nullable columns into optional pointer fields.
	if knownAt.Valid {
		stamp := knownAt.Time
		perf.MarkedAsKnownAt = &stamp
	}
	if confidence.Valid {
		lvl := int(confidence.Int32)
		perf.ConfidenceLevel = &lvl
	}
	return perf, nil
}
// calculatePerformanceMultiplier converts a question's answer history into a
// priority multiplier:
//
//	1 + errorRate*weakAreaBoost - successRate*0.5, clamped to [0.1, 10]
//
// Questions that have never been answered get the neutral value 1.0.
func (s *LearningService) calculatePerformanceMultiplier(performance *QuestionPerformance, weakAreaBoost float64) float64 {
	// Pure arithmetic; no tracing needed.
	answered := performance.TimesAnswered
	if answered == 0 {
		return 1.0
	}
	attempts := float64(answered)
	wrongShare := float64(answered-performance.CorrectAnswers) / attempts
	rightShare := float64(performance.CorrectAnswers) / attempts

	// Frequently missed questions are pushed up; reliably correct ones down.
	m := 1.0 + (wrongShare * weakAreaBoost) - (rightShare * 0.5)
	switch {
	case m < 0.1:
		return 0.1
	case m > 10.0:
		return 10.0
	}
	return m
}
// calculateSpacedRepetitionBoost grows linearly by 0.1 per day elapsed since
// lastSeenAt (1.0 + 0.1*days), capped at 5.0. A nil lastSeenAt (never seen)
// yields 1.0.
func (s *LearningService) calculateSpacedRepetitionBoost(lastSeenAt *time.Time) float64 {
	// Pure computation apart from reading the clock; no tracing needed.
	if lastSeenAt == nil {
		return 1.0
	}
	elapsedDays := time.Since(*lastSeenAt).Hours() / 24.0
	return math.Min(1.0+(elapsedDays*0.1), 5.0)
}
// calculateUserPreferenceMultiplier turns the user's "mark as known" signal
// (with an optional confidence level) into a priority multiplier.
//
// Policy:
//   - not marked as known:            1.0 (no effect)
//   - confidence 1:                   1.25 (show noticeably more often)
//   - confidence 2:                   1.10 (show slightly more often)
//   - confidence 3 or outside 1-5:    1.0 (neutral)
//   - confidence 4:                   KnownQuestionPenalty * 0.5
//   - confidence 5:                   KnownQuestionPenalty * 0.1
//   - marked known, no confidence:    KnownQuestionPenalty
//
// With the default KnownQuestionPenalty of 0.1 the last three cases reduce
// priority; other configured penalties scale them accordingly.
func (s *LearningService) calculateUserPreferenceMultiplier(performance *QuestionPerformance, prefs *models.UserLearningPreferences) float64 {
	// Pure lookup; no tracing needed.
	if !performance.MarkedAsKnown {
		return 1.0
	}
	if performance.ConfidenceLevel == nil {
		return prefs.KnownQuestionPenalty
	}
	switch *performance.ConfidenceLevel {
	case 1:
		return 1.25
	case 2:
		return 1.10
	case 4:
		return prefs.KnownQuestionPenalty * 0.5
	case 5:
		return prefs.KnownQuestionPenalty * 0.1
	default:
		// Covers the neutral level 3 and any unexpected value.
		return 1.0
	}
}
// calculateFreshnessBoost favors questions the user has never answered: it
// returns 1.5 when timesAnswered is zero and 1.0 for anything else.
func (s *LearningService) calculateFreshnessBoost(timesAnswered int) float64 {
	// Pure computation with no external calls, so it is not traced.
	boost := 1.0
	if timesAnswered == 0 {
		boost = 1.5
	}
	return boost
}
// isForeignKeyConstraintViolation reports whether err represents a PostgreSQL
// foreign key constraint violation (SQLSTATE 23503).
//
// The error chain is walked through single-error Unwrap methods, so a
// *pq.Error that was wrapped with extra context is still recognized. If no
// *pq.Error carrying that code is found, the full message is checked for the
// "violates foreign key constraint" phrase as a fallback.
// A nil err returns false.
func isForeignKeyConstraintViolation(err error) bool {
	if err == nil {
		return false
	}
	// Walk the wrap chain looking for the driver's structured error.
	for e := err; e != nil; {
		if pqErr, ok := e.(*pq.Error); ok && pqErr.Code == "23503" {
			return true
		}
		wrapper, ok := e.(interface{ Unwrap() error })
		if !ok {
			break
		}
		e = wrapper.Unwrap()
	}
	// Fallback for errors that only carry the driver's message text.
	return strings.Contains(err.Error(), "violates foreign key constraint")
}
// Analytics Methods
// GetPriorityScoreDistribution summarizes every stored priority score greater
// than zero, for all users and joined against questions, as band counts plus
// the overall mean.
//
// The returned map holds "high" (score > 200), "medium" (100 to 200
// inclusive) and "low" (score < 100) as ints, and "average" as a float64
// that is 0.0 when no rows qualify.
func (s *LearningService) GetPriorityScoreDistribution(ctx context.Context) (result0 map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_priority_score_distribution")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
SELECT
COUNT(CASE WHEN qps.priority_score > 200 THEN 1 END) as high,
COUNT(CASE WHEN qps.priority_score BETWEEN 100 AND 200 THEN 1 END) as medium,
COUNT(CASE WHEN qps.priority_score < 100 THEN 1 END) as low,
AVG(qps.priority_score) as average
FROM question_priority_scores qps
JOIN questions q ON qps.question_id = q.id
WHERE qps.priority_score > 0
`
	var (
		high, medium, low int
		average           sql.NullFloat64
	)
	if err = s.db.QueryRowContext(ctx, query).Scan(&high, &medium, &low, &average); err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get priority score distribution: %w", err)
	}

	// AVG over zero rows is NULL; report 0.0 in that case.
	mean := 0.0
	if average.Valid {
		mean = average.Float64
	}

	span.SetAttributes(
		attribute.Int("high_count", high),
		attribute.Int("medium_count", medium),
		attribute.Int("low_count", low),
		attribute.Float64("average_score", mean),
	)
	return map[string]interface{}{
		"high":    high,
		"medium":  medium,
		"low":     low,
		"average": mean,
	}, nil
}
// GetHighPriorityQuestions returns up to limit priority-score rows (all users)
// whose score exceeds 200, ordered from the highest score down.
//
// Each entry carries "question_type", "level", "topic" (empty strings when
// the column is NULL) and "priority_score". Rows that fail to scan are
// skipped and counted in the span's "scan_errors" attribute. An error raised
// while iterating the result set is returned instead of a silently truncated
// list.
func (s *LearningService) GetHighPriorityQuestions(ctx context.Context, limit int) (result0 []map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_high_priority_questions",
		attribute.Int("limit", limit),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
SELECT
q.type as question_type,
q.level,
q.topic_category as topic,
qps.priority_score
FROM question_priority_scores qps
JOIN questions q ON qps.question_id = q.id
WHERE qps.priority_score > 200
ORDER BY qps.priority_score DESC
LIMIT $1
`
	rows, err := s.db.QueryContext(ctx, query, limit)
	if err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get high priority questions: %w", err)
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var questions []map[string]interface{}
	var scanErrors int
	for rows.Next() {
		var questionType, level, topic sql.NullString
		var priorityScore float64
		if scanErr := rows.Scan(&questionType, &level, &topic, &priorityScore); scanErr != nil {
			// Best-effort: skip the malformed row but keep it visible in the trace.
			scanErrors++
			continue
		}
		questions = append(questions, map[string]interface{}{
			"question_type":  questionType.String,
			"level":          level.String,
			"topic":          topic.String,
			"priority_score": priorityScore,
		})
	}
	// rows.Next returning false can mean an error, not just end of data.
	if err := rows.Err(); err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "error during rows iteration: %w", err)
	}

	span.SetAttributes(
		attribute.Int("questions_count", len(questions)),
		attribute.Int("scan_errors", scanErrors),
	)
	return questions, nil
}
// GetWeakAreasByTopic returns up to limit topics ordered from the lowest to
// the highest aggregate accuracy. Accuracy is computed over all
// performance_metrics rows (all users) that have at least one attempt.
//
// Each entry carries "topic" (empty string when NULL), plus "total_attempts"
// and "correct_attempts", both summed per topic. Rows that fail to scan are
// skipped and counted in the span's "scan_errors" attribute. An error raised
// while iterating the result set is returned instead of a silently truncated
// list.
func (s *LearningService) GetWeakAreasByTopic(ctx context.Context, limit int) (result0 []map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_weak_areas_by_topic",
		attribute.Int("limit", limit),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
SELECT
topic,
SUM(total_attempts) as total_attempts,
SUM(correct_attempts) as correct_attempts
FROM performance_metrics
WHERE total_attempts > 0
GROUP BY topic
ORDER BY (SUM(correct_attempts)::float / SUM(total_attempts)) ASC
LIMIT $1
`
	rows, err := s.db.QueryContext(ctx, query, limit)
	if err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get weak areas: %w", err)
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var weakAreas []map[string]interface{}
	var scanErrors int
	for rows.Next() {
		var topic sql.NullString
		var totalAttempts, correctAttempts int
		if scanErr := rows.Scan(&topic, &totalAttempts, &correctAttempts); scanErr != nil {
			// Best-effort: skip the malformed row but keep it visible in the trace.
			scanErrors++
			continue
		}
		weakAreas = append(weakAreas, map[string]interface{}{
			"topic":            topic.String,
			"total_attempts":   totalAttempts,
			"correct_attempts": correctAttempts,
		})
	}
	// rows.Next returning false can mean an error, not just end of data.
	if err := rows.Err(); err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "error during rows iteration: %w", err)
	}

	span.SetAttributes(
		attribute.Int("weak_areas_count", len(weakAreas)),
		attribute.Int("scan_errors", scanErrors),
	)
	return weakAreas, nil
}
// GetLearningPreferencesUsage aggregates the user_learning_preferences table
// into a single summary map.
//
// Keys: "total_users" (int), "focusOnWeakAreas" (bool, true when more than
// half of the rows enable it), "freshQuestionRatio", "weakAreaBoost" and
// "knownQuestionPenalty" (float64 averages). When the table is empty, or an
// average is NULL, the fallbacks are 0, false, 0.3, 2.0 and 0.1 respectively.
func (s *LearningService) GetLearningPreferencesUsage(ctx context.Context) (result0 map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_learning_preferences_usage")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
SELECT
COUNT(*) as total_users,
AVG(focus_on_weak_areas::int) as avg_focus_on_weak_areas,
AVG(fresh_question_ratio) as avg_fresh_question_ratio,
AVG(weak_area_boost) as avg_weak_area_boost,
AVG(known_question_penalty) as avg_known_question_penalty
FROM user_learning_preferences
`
	var totalUsers int
	var focusAvg, freshAvg, boostAvg, penaltyAvg sql.NullFloat64
	if err = s.db.QueryRowContext(ctx, query).Scan(&totalUsers, &focusAvg, &freshAvg, &boostAvg, &penaltyAvg); err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get learning preferences usage: %w", err)
	}

	// Fallback values reported when no preferences have been stored.
	users := 0
	focusOnWeakAreas := false
	freshRatio, weakBoost, knownPenalty := 0.3, 2.0, 0.1

	if totalUsers > 0 {
		users = totalUsers
		if focusAvg.Valid {
			// The average of a 0/1 column above 0.5 means the majority enabled it.
			focusOnWeakAreas = focusAvg.Float64 > 0.5
		}
		if freshAvg.Valid {
			freshRatio = freshAvg.Float64
		}
		if boostAvg.Valid {
			weakBoost = boostAvg.Float64
		}
		if penaltyAvg.Valid {
			knownPenalty = penaltyAvg.Float64
		}
	}

	span.SetAttributes(
		attribute.Int("total_users", users),
		attribute.Bool("focus_on_weak_areas", focusOnWeakAreas),
		attribute.Float64("fresh_question_ratio", freshRatio),
		attribute.Float64("weak_area_boost", weakBoost),
		attribute.Float64("known_question_penalty", knownPenalty),
	)
	return map[string]interface{}{
		"total_users":          users,
		"focusOnWeakAreas":     focusOnWeakAreas,
		"freshQuestionRatio":   freshRatio,
		"weakAreaBoost":        weakBoost,
		"knownQuestionPenalty": knownPenalty,
	}, nil
}
// GetQuestionTypeGaps lists (question type, level) groups where fewer than 80%
// of the rows in a questions LEFT JOIN question_priority_scores have a
// priority score. Results are ordered from the worst coverage ratio upward.
//
// Each entry holds "question_type" and "level" (empty strings for NULL),
// "available" (COUNT(q.id) for the group) and "demand" (available minus the
// number of joined score rows). Rows that fail to scan are skipped and
// counted in the span's "scan_errors" attribute. An error raised while
// iterating the result set is returned.
//
// NOTE(review): because of the LEFT JOIN, a question scored for several users
// is counted once per score row in both COUNTs. If "available" is meant to be
// a number of distinct questions, COUNT(DISTINCT ...) may be intended; confirm.
func (s *LearningService) GetQuestionTypeGaps(ctx context.Context) (result0 []map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_question_type_gaps")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
SELECT
q.type as question_type,
q.level,
COUNT(q.id) as available,
COUNT(qps.question_id) as with_priority_scores
FROM questions q
LEFT JOIN question_priority_scores qps ON q.id = qps.question_id
GROUP BY q.type, q.level
HAVING COUNT(qps.question_id) < COUNT(q.id) * 0.8
ORDER BY (COUNT(qps.question_id)::float / COUNT(q.id)) ASC
`
	rows, err := s.db.QueryContext(ctx, query)
	if err != nil {
		span.SetAttributes(attribute.String("error.type", "database_query_failed"), attribute.String("error", err.Error()))
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get question type gaps: %w", err)
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows in GetQuestionTypeGaps", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var (
		gaps       []map[string]interface{}
		scanErrors int
	)
	for rows.Next() {
		var questionType, level sql.NullString
		var available, withPriorityScores int
		if scanErr := rows.Scan(&questionType, &level, &available, &withPriorityScores); scanErr != nil {
			scanErrors++
			span.SetAttributes(attribute.String("error.type", "row_scan_failed"), attribute.String("error", scanErr.Error()))
			continue
		}
		gaps = append(gaps, map[string]interface{}{
			"question_type": questionType.String,
			"level":         level.String,
			"available":     available,
			"demand":        available - withPriorityScores,
		})
	}
	if iterErr := rows.Err(); iterErr != nil {
		span.SetAttributes(attribute.String("error.type", "rows_iteration_failed"), attribute.String("error", iterErr.Error()))
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "error during rows iteration: %w", iterErr)
	}

	span.SetAttributes(
		attribute.Int("gaps_count", len(gaps)),
		attribute.Int("scan_errors", scanErrors),
	)
	return gaps, nil
}
// GetGenerationSuggestions flags (type, level, language) groups that look
// under-supplied. A group qualifies when it has fewer than 50 rows in the
// questions LEFT JOIN question_priority_scores, or fewer than 10 joined
// scores above 100. Results are sorted by ascending row count, then by
// descending average priority.
//
// Each entry holds "question_type", "level" and "language" (empty strings for
// NULL), "available" and "high_priority" (ints), and "avg_priority" and
// "priority_score" (both the group's average score, 0.0 when it has none).
// Rows that fail to scan are skipped and counted in the span's "scan_errors"
// attribute. An error raised while iterating the result set is returned.
//
// NOTE(review): like GetQuestionTypeGaps, COUNT(q.id) over the LEFT JOIN
// counts one row per score row, not per distinct question; confirm intent.
func (s *LearningService) GetGenerationSuggestions(ctx context.Context) (result0 []map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_generation_suggestions")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
SELECT
q.type as question_type,
q.level,
q.language,
COUNT(q.id) as available,
COUNT(CASE WHEN qps.priority_score > 100 THEN 1 END) as high_priority,
AVG(qps.priority_score) as avg_priority
FROM questions q
LEFT JOIN question_priority_scores qps ON q.id = qps.question_id
GROUP BY q.type, q.level, q.language
HAVING COUNT(q.id) < 50 OR COUNT(CASE WHEN qps.priority_score > 100 THEN 1 END) < 10
ORDER BY COUNT(q.id) ASC, AVG(qps.priority_score) DESC
`
	rows, err := s.db.QueryContext(ctx, query)
	if err != nil {
		span.SetAttributes(attribute.String("error.type", "database_query_failed"), attribute.String("error", err.Error()))
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get generation suggestions: %w", err)
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows in GetGenerationSuggestions", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var (
		suggestions []map[string]interface{}
		scanErrors  int
	)
	for rows.Next() {
		var questionType, level, language sql.NullString
		var available, highPriority int
		var avgPriority sql.NullFloat64
		if scanErr := rows.Scan(&questionType, &level, &language, &available, &highPriority, &avgPriority); scanErr != nil {
			scanErrors++
			span.SetAttributes(attribute.String("error.type", "row_scan_failed"), attribute.String("error", scanErr.Error()))
			continue
		}

		// AVG is NULL for groups without any joined score rows.
		meanPriority := 0.0
		if avgPriority.Valid {
			meanPriority = avgPriority.Float64
		}
		suggestions = append(suggestions, map[string]interface{}{
			"question_type":  questionType.String,
			"level":          level.String,
			"language":       language.String,
			"available":      available,
			"high_priority":  highPriority,
			"avg_priority":   meanPriority,
			"priority_score": meanPriority,
		})
	}
	if iterErr := rows.Err(); iterErr != nil {
		span.SetAttributes(attribute.String("error.type", "rows_iteration_failed"), attribute.String("error", iterErr.Error()))
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "error during rows iteration: %w", iterErr)
	}

	span.SetAttributes(
		attribute.Int("suggestions_count", len(suggestions)),
		attribute.Int("scan_errors", scanErrors),
	)
	return suggestions, nil
}
// GetPrioritySystemPerformance reports rough throughput figures for the
// priority-score system, based on score rows recalculated within the last
// hour.
//
// Keys: "calculationsPerSecond" (that hour's row count divided by 3600) and
// "avgScore" (mean score of those rows, 0.0 when none). "avgCalculationTime",
// "avgQueryTime" and "memoryUsage" are placeholders fixed at 0.0 because they
// are not tracked. "lastCalculation" (RFC 3339) is present only when a recent
// calculation exists.
func (s *LearningService) GetPrioritySystemPerformance(ctx context.Context) (result0 map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_priority_system_performance")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Simplified: derived from score timestamps rather than real instrumentation.
	query := `
SELECT
COUNT(*) as total_calculations,
AVG(priority_score) as avg_score,
MAX(last_calculated_at) as last_calculation
FROM question_priority_scores
WHERE last_calculated_at > NOW() - INTERVAL '1 hour'
`
	var (
		totalCalculations int
		avgScore          sql.NullFloat64
		lastCalculation   sql.NullTime
	)
	if err = s.db.QueryRowContext(ctx, query).Scan(&totalCalculations, &avgScore, &lastCalculation); err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get priority system performance: %w", err)
	}

	perSecond := float64(totalCalculations) / 3600.0 // hourly count spread over 3600 seconds
	meanScore := 0.0
	if avgScore.Valid {
		meanScore = avgScore.Float64
	}

	result := map[string]interface{}{
		"calculationsPerSecond": perSecond,
		"avgCalculationTime":    0.0, // not tracked
		"avgQueryTime":          0.0, // not tracked
		"memoryUsage":           0.0, // not tracked
		"avgScore":              meanScore,
	}
	if lastCalculation.Valid {
		result["lastCalculation"] = lastCalculation.Time.Format(time.RFC3339)
	}

	span.SetAttributes(
		attribute.Float64("calculations_per_second", perSecond),
		attribute.Float64("avg_score", meanScore),
		attribute.Int("total_calculations", totalCalculations),
	)
	return result, nil
}
// GetBackgroundJobsStatus approximates background-job health from
// priority-score rows updated within the last minute.
//
// Keys: "priorityUpdates" (rows updated in that window) and "lastUpdate" (the
// newest update as RFC 3339, or "N/A"). "queueSize" is a placeholder fixed at
// 0. "status" is "idle" when nothing was updated and "healthy" otherwise.
func (s *LearningService) GetBackgroundJobsStatus(ctx context.Context) (result0 map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_background_jobs_status")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Simplified: inferred from score timestamps rather than a real job tracker.
	query := `
SELECT
COUNT(*) as total_updates,
MAX(updated_at) as last_update
FROM question_priority_scores
WHERE updated_at > NOW() - INTERVAL '1 minute'
`
	var totalUpdates int
	var lastUpdate sql.NullTime
	if err = s.db.QueryRowContext(ctx, query).Scan(&totalUpdates, &lastUpdate); err != nil {
		return nil, contextutils.WrapError(err, "failed to get background jobs status")
	}

	lastUpdateText := "N/A"
	if lastUpdate.Valid {
		lastUpdateText = lastUpdate.Time.Format(time.RFC3339)
	}

	var status string
	if totalUpdates == 0 {
		status = "idle"
	} else {
		status = "healthy"
	}

	queueSize := 0 // queue size is not tracked yet

	span.SetAttributes(
		attribute.Int("priority_updates", totalUpdates),
		attribute.String("status", status),
		attribute.Int("queue_size", queueSize),
	)
	return map[string]interface{}{
		"priorityUpdates": totalUpdates,
		"lastUpdate":      lastUpdateText,
		"queueSize":       queueSize,
		"status":          status,
	}, nil
}
// GetUserPriorityScoreDistribution buckets one user's positive priority scores
// into "high" (> 200), "medium" (100 to 200 inclusive) and "low" (< 100)
// counts (ints), plus their mean as "average" (float64, 0.0 when the user has
// no qualifying rows).
func (s *LearningService) GetUserPriorityScoreDistribution(ctx context.Context, userID int) (result0 map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_user_priority_score_distribution",
		observability.AttributeUserID(userID),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
SELECT
COUNT(CASE WHEN priority_score > 200 THEN 1 END) as high,
COUNT(CASE WHEN priority_score BETWEEN 100 AND 200 THEN 1 END) as medium,
COUNT(CASE WHEN priority_score < 100 THEN 1 END) as low,
AVG(priority_score) as average
FROM question_priority_scores
WHERE user_id = $1 AND priority_score > 0
`
	var (
		high, medium, low int
		average           sql.NullFloat64
	)
	if err = s.db.QueryRowContext(ctx, query, userID).Scan(&high, &medium, &low, &average); err != nil {
		return nil, contextutils.WrapError(err, "failed to get user priority score distribution")
	}

	// AVG over zero rows is NULL; report 0.0 in that case.
	mean := 0.0
	if average.Valid {
		mean = average.Float64
	}

	span.SetAttributes(
		attribute.Int("high_count", high),
		attribute.Int("medium_count", medium),
		attribute.Int("low_count", low),
		attribute.Float64("average_score", mean),
	)
	return map[string]interface{}{
		"high":    high,
		"medium":  medium,
		"low":     low,
		"average": mean,
	}, nil
}
// GetUserHighPriorityQuestions returns up to limit of the user's questions
// whose priority score exceeds 200, ordered from the highest score down.
//
// Each entry carries "question_type", "level", "topic" (empty strings when
// the column is NULL) and "priority_score". Rows that fail to scan are
// skipped and counted in the span's "scan_errors" attribute. An error raised
// while iterating the result set is returned instead of a silently truncated
// list.
func (s *LearningService) GetUserHighPriorityQuestions(ctx context.Context, userID, limit int) (result0 []map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_user_high_priority_questions",
		observability.AttributeUserID(userID),
		attribute.Int("limit", limit),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
SELECT
q.type as question_type,
q.level,
q.topic_category as topic,
qps.priority_score
FROM question_priority_scores qps
JOIN questions q ON qps.question_id = q.id
WHERE qps.user_id = $1 AND qps.priority_score > 200
ORDER BY qps.priority_score DESC
LIMIT $2
`
	rows, err := s.db.QueryContext(ctx, query, userID, limit)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get user high priority questions")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var questions []map[string]interface{}
	var scanErrors int
	for rows.Next() {
		var questionType, level, topic sql.NullString
		var priorityScore float64
		if scanErr := rows.Scan(&questionType, &level, &topic, &priorityScore); scanErr != nil {
			// Best-effort: skip the malformed row but keep it visible in the trace.
			scanErrors++
			continue
		}
		questions = append(questions, map[string]interface{}{
			"question_type":  questionType.String,
			"level":          level.String,
			"topic":          topic.String,
			"priority_score": priorityScore,
		})
	}
	// rows.Next returning false can mean an error, not just end of data.
	if err := rows.Err(); err != nil {
		return nil, contextutils.WrapError(err, "failed to iterate user high priority questions")
	}

	span.SetAttributes(
		attribute.Int("questions_count", len(questions)),
		attribute.Int("scan_errors", scanErrors),
	)
	return questions, nil
}
// GetUserWeakAreas returns up to limit of the user's performance_metrics rows
// that have at least one attempt, ordered from the lowest accuracy upward.
//
// Each entry carries "topic" (empty string when NULL), "total_attempts" and
// "correct_attempts". Rows that fail to scan are skipped and counted in the
// span's "scan_errors" attribute. An error raised while iterating the result
// set is returned instead of a silently truncated list.
func (s *LearningService) GetUserWeakAreas(ctx context.Context, userID, limit int) (result0 []map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_user_weak_areas",
		observability.AttributeUserID(userID),
		attribute.Int("limit", limit),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
SELECT
topic,
total_attempts,
correct_attempts
FROM performance_metrics
WHERE user_id = $1 AND total_attempts > 0
ORDER BY (correct_attempts::float / total_attempts) ASC
LIMIT $2
`
	rows, err := s.db.QueryContext(ctx, query, userID, limit)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get user weak areas")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var weakAreas []map[string]interface{}
	var scanErrors int
	for rows.Next() {
		var topic sql.NullString
		var totalAttempts, correctAttempts int
		if scanErr := rows.Scan(&topic, &totalAttempts, &correctAttempts); scanErr != nil {
			// Best-effort: skip the malformed row but keep it visible in the trace.
			scanErrors++
			continue
		}
		weakAreas = append(weakAreas, map[string]interface{}{
			"topic":            topic.String,
			"total_attempts":   totalAttempts,
			"correct_attempts": correctAttempts,
		})
	}
	// rows.Next returning false can mean an error, not just end of data.
	if err := rows.Err(); err != nil {
		return nil, contextutils.WrapError(err, "failed to iterate user weak areas")
	}

	span.SetAttributes(
		attribute.Int("weak_areas_count", len(weakAreas)),
		attribute.Int("scan_errors", scanErrors),
	)
	return weakAreas, nil
}
// Priority generation methods moved to worker
// GetHighPriorityTopics returns up to five topic categories, taken from
// questions linked to the user through user_questions, whose average priority
// score for that user is at least 150. Topics are ordered from the highest
// average down.
//
// Questions with a NULL or empty topic_category are ignored. The result is
// never nil: an empty slice is returned when no topic qualifies. Rows that
// fail to scan are skipped and counted in the span's "scan_errors" attribute.
// An error raised while iterating the result set is returned instead of a
// silently truncated list.
func (s *LearningService) GetHighPriorityTopics(ctx context.Context, userID int) (result0 []string, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_high_priority_topics",
		observability.AttributeUserID(userID),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
SELECT q.topic_category, AVG(qps.priority_score) as avg_score
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
JOIN question_priority_scores qps ON q.id = qps.question_id AND qps.user_id = $1
WHERE uq.user_id = $1
AND q.topic_category IS NOT NULL
AND q.topic_category != ''
GROUP BY q.topic_category
HAVING AVG(qps.priority_score) >= 150.0
ORDER BY avg_score DESC
LIMIT 5
`
	rows, err := s.db.QueryContext(ctx, query, userID)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get high priority topics")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	// Start non-nil so callers (e.g. JSON encoders) see [] rather than null.
	topics := []string{}
	var scanErrors int
	for rows.Next() {
		var topic string
		var avgScore float64
		if scanErr := rows.Scan(&topic, &avgScore); scanErr != nil {
			// Best-effort: skip the malformed row but keep it visible in the trace.
			scanErrors++
			continue
		}
		topics = append(topics, topic)
	}
	// rows.Next returning false can mean an error, not just end of data.
	if err := rows.Err(); err != nil {
		return nil, contextutils.WrapError(err, "failed to iterate high priority topics")
	}

	span.SetAttributes(
		attribute.Int("topics_count", len(topics)),
		attribute.Int("scan_errors", scanErrors),
	)
	return topics, nil
}
// GetGapAnalysis reports the user's knowledge gaps: performance_metrics rows
// with at least three attempts and an accuracy below 70%. At most ten rows are
// considered, lowest accuracy first.
//
// The result is keyed by topic. Each value is a map with "topic",
// "total_questions" and "accuracy_percentage" (rounded to two decimals, 0.0
// if NULL). Rows that fail to scan are skipped and counted in the span's
// "scan_errors" attribute. An error raised while iterating the result set is
// returned instead of a partial map.
//
// NOTE(review): the query groups by (topic, correct_attempts, total_attempts),
// so "total_questions" counts metric rows sharing exactly those values
// (usually 1), not a per-topic total. If a user has several metric rows for
// one topic, only the last one scanned (the highest-accuracy one, given the
// ASC ordering) survives in the map. Per-topic aggregation with SUM may be
// what was intended; confirm against the performance_metrics schema.
func (s *LearningService) GetGapAnalysis(ctx context.Context, userID int) (result0 map[string]interface{}, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_gap_analysis",
		observability.AttributeUserID(userID),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Query to find areas where user has poor performance (low accuracy)
	query := `
SELECT
pm.topic,
COUNT(*) as total_questions,
ROUND((pm.correct_attempts * 100.0 / pm.total_attempts), 2) as accuracy_percentage
FROM performance_metrics pm
WHERE pm.user_id = $1
AND pm.total_attempts >= 3
AND (pm.correct_attempts * 100.0 / pm.total_attempts) < 70.0
GROUP BY pm.topic, pm.correct_attempts, pm.total_attempts
ORDER BY accuracy_percentage ASC
LIMIT 10
`
	rows, err := s.db.QueryContext(ctx, query, userID)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get gap analysis")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	gaps := make(map[string]interface{})
	var scanErrors int
	for rows.Next() {
		var topic string
		var totalQuestions int
		var accuracyPercentage sql.NullFloat64
		if scanErr := rows.Scan(&topic, &totalQuestions, &accuracyPercentage); scanErr != nil {
			// Best-effort: skip the malformed row but keep it visible in the trace.
			scanErrors++
			continue
		}
		accuracy := 0.0
		if accuracyPercentage.Valid {
			accuracy = accuracyPercentage.Float64
		}
		gaps[topic] = map[string]interface{}{
			"topic":               topic,
			"total_questions":     totalQuestions,
			"accuracy_percentage": accuracy,
		}
	}
	// rows.Next returning false can mean an error, not just end of data.
	if err := rows.Err(); err != nil {
		return nil, contextutils.WrapError(err, "failed to iterate gap analysis rows")
	}

	span.SetAttributes(
		attribute.Int("gaps_count", len(gaps)),
		attribute.Int("scan_errors", scanErrors),
	)
	return gaps, nil
}
// GetPriorityDistribution counts, per topic category, the joined rows of the
// user's questions (via user_questions) that have a priority score for that
// user. Questions with a NULL or empty topic_category are excluded.
//
// The result maps topic -> count and is empty (non-nil) when nothing matches.
// Rows that fail to scan are skipped and counted in the span's "scan_errors"
// attribute. An error raised while iterating the result set is returned
// instead of a partial map.
func (s *LearningService) GetPriorityDistribution(ctx context.Context, userID int) (result0 map[string]int, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_priority_distribution",
		observability.AttributeUserID(userID),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Query to get priority score distribution by topic
	query := `
SELECT q.topic_category, COUNT(*) as question_count
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
JOIN question_priority_scores qps ON q.id = qps.question_id AND qps.user_id = $1
WHERE uq.user_id = $1
AND q.topic_category IS NOT NULL
AND q.topic_category != ''
GROUP BY q.topic_category
ORDER BY question_count DESC
`
	rows, err := s.db.QueryContext(ctx, query, userID)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get priority distribution")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	distribution := make(map[string]int)
	var scanErrors int
	for rows.Next() {
		var topic string
		var count int
		if scanErr := rows.Scan(&topic, &count); scanErr != nil {
			// Best-effort: skip the malformed row but keep it visible in the trace.
			scanErrors++
			continue
		}
		distribution[topic] = count
	}
	// rows.Next returning false can mean an error, not just end of data.
	if err := rows.Err(); err != nil {
		return nil, contextutils.WrapError(err, "failed to iterate priority distribution rows")
	}

	span.SetAttributes(
		attribute.Int("topics_count", len(distribution)),
		attribute.Int("scan_errors", scanErrors),
	)
	return distribution, nil
}
// GetUserQuestionConfidenceLevel looks up the confidence level the user
// recorded for a question in user_question_metadata.
//
// It returns (nil, nil) when no metadata row exists or its confidence_level
// is NULL. A pointer to the level is returned when one is stored. Any other
// database error is returned wrapped.
func (s *LearningService) GetUserQuestionConfidenceLevel(ctx context.Context, userID, questionID int) (result0 *int, err error) {
	ctx, span := observability.TraceLearningFunction(ctx, "get_user_question_confidence_level",
		observability.AttributeUserID(userID),
		observability.AttributeQuestionID(questionID),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
SELECT confidence_level
FROM user_question_metadata
WHERE user_id = $1 AND question_id = $2
`
	var confidenceLevel sql.NullInt32
	switch scanErr := s.db.QueryRowContext(ctx, query, userID, questionID).Scan(&confidenceLevel); {
	case scanErr == sql.ErrNoRows:
		// Nothing recorded for this user/question pair: not an error.
		return nil, nil
	case scanErr != nil:
		return nil, contextutils.WrapError(scanErr, "failed to get user question confidence level")
	case !confidenceLevel.Valid:
		// Row exists but the level column is NULL.
		return nil, nil
	}

	level := int(confidenceLevel.Int32)
	return &level, nil
}
package services
import (
"fmt"
contextutils "quizapp/internal/utils"
)
// NoQuestionsAvailableError reports that question assignment found nothing
// suitable, along with the search parameters and counts involved.
type NoQuestionsAvailableError struct {
	// Language and Level are the requested language and proficiency level.
	Language string
	Level    string
	// CandidateIDs presumably lists the candidate question IDs that were
	// examined; it is not included in the Error() text.
	CandidateIDs []int
	// CandidateCount and TotalMatching are reported verbatim in Error().
	CandidateCount int
	TotalMatching  int
}

// Error formats the language, level, candidate count and total match count
// into a single diagnostic line.
func (e *NoQuestionsAvailableError) Error() string {
	const format = "no questions available for assignment (language=%s level=%s candidate_count=%d total_matching=%d)"
	return fmt.Sprintf(format, e.Language, e.Level, e.CandidateCount, e.TotalMatching)
}
// Unwrap exposes contextutils.ErrNoQuestionsAvailable as the underlying cause.
// errors.Is(err, contextutils.ErrNoQuestionsAvailable) therefore matches this
// error, while errors.As can still recover the detailed fields.
func (e *NoQuestionsAvailableError) Unwrap() error {
	cause := contextutils.ErrNoQuestionsAvailable
	return cause
}
package services
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
)
// Sentinel errors returned by this file's OAuth and sign-up paths.
var (
	// ErrSignupsDisabled is returned when user registration has been turned
	// off in the configuration.
	ErrSignupsDisabled = errors.New("user registration is currently disabled")

	// ErrOAuthCodeAlreadyUsed corresponds to the provider's "invalid_grant" error.
	ErrOAuthCodeAlreadyUsed = errors.New("authorization code has already been used")
	// ErrOAuthClientConfig corresponds to the provider's "invalid_client" error.
	ErrOAuthClientConfig = errors.New("OAuth client configuration error")
	// ErrOAuthInvalidRequest corresponds to the provider's "invalid_request" error.
	ErrOAuthInvalidRequest = errors.New("invalid OAuth request")
	// ErrOAuthUnauthorized corresponds to the provider's "unauthorized_client" error.
	ErrOAuthUnauthorized = errors.New("OAuth client is not authorized")
	// ErrOAuthUnsupportedGrant corresponds to the provider's "unsupported_grant_type" error.
	ErrOAuthUnsupportedGrant = errors.New("unsupported OAuth grant type")
)
// OAuthService drives the Google OAuth 2.0 sign-in flow. It builds the consent
// URL, exchanges authorization codes for tokens and fetches the user's profile.
type OAuthService struct {
	// config supplies the Google client ID, secret and redirect URL.
	config *config.Config
	// TokenEndpoint is the token URL used for code exchange. It is exported
	// so tests can point it at a mock server.
	TokenEndpoint string
	// UserInfoEndpoint is the profile URL. It is exported so tests can point
	// it at a mock server.
	UserInfoEndpoint string
	// logger receives warnings about misconfiguration and cleanup failures.
	logger *observability.Logger
}
// NewOAuthServiceWithLogger builds an OAuthService from cfg that talks to
// Google's production token and userinfo endpoints and logs through logger.
// Both endpoint fields are exported and may be overwritten afterwards, for
// example to target a test server.
func NewOAuthServiceWithLogger(cfg *config.Config, logger *observability.Logger) *OAuthService {
	svc := &OAuthService{config: cfg, logger: logger}
	svc.TokenEndpoint = "https://oauth2.googleapis.com/token"
	svc.UserInfoEndpoint = "https://www.googleapis.com/oauth2/v2/userinfo"
	return svc
}
// GoogleUserInfo is the profile payload returned by Google's OAuth2 userinfo
// endpoint. The JSON tags match that response's field names.
type GoogleUserInfo struct {
	// ID is the Google account identifier.
	ID string `json:"id"`
	// Email is the account's primary e-mail address.
	Email string `json:"email"`
	// Name is the display name; GivenName and FamilyName are its parts.
	Name       string `json:"name"`
	GivenName  string `json:"given_name"`
	FamilyName string `json:"family_name"`
	// Picture is the URL of the profile image.
	Picture string `json:"picture"`
	// VerifiedEmail reports whether Google has verified Email.
	VerifiedEmail bool `json:"verified_email"`
}
// GoogleTokenResponse is the JSON body Google's token endpoint returns for a
// successful code exchange. RefreshToken and IDToken are optional and are
// omitted when re-encoded empty.
type GoogleTokenResponse struct {
	// AccessToken authorizes subsequent Google API calls such as userinfo.
	AccessToken string `json:"access_token"`
	// TokenType is the token scheme, normally "Bearer".
	TokenType string `json:"token_type"`
	// ExpiresIn is the access token lifetime in seconds.
	ExpiresIn int `json:"expires_in"`
	// RefreshToken is present only when offline access was granted.
	RefreshToken string `json:"refresh_token,omitempty"`
	// IDToken is the OpenID Connect identity token, when issued.
	IDToken string `json:"id_token,omitempty"`
}
// GetGoogleAuthURL builds the Google consent-screen URL for the authorization
// code flow. It requests the "openid email profile" scopes, offline access and
// a forced consent prompt, and echoes state back for CSRF validation.
//
// A missing client ID or redirect URL is only logged as a warning when a
// logger is configured; the URL is still returned.
func (s *OAuthService) GetGoogleAuthURL(ctx context.Context, state string) string {
	clientID := s.config.GoogleOAuthClientID
	redirectURL := s.config.GoogleOAuthRedirectURL

	_, span := observability.TraceOAuthFunction(ctx, "get_google_auth_url",
		attribute.String("oauth.state", state),
		attribute.String("oauth.client_id", clientID),
		attribute.String("oauth.redirect_url", redirectURL),
	)
	defer span.End()

	// Surface misconfiguration in the logs.
	if s.logger != nil {
		if clientID == "" {
			s.logger.Warn(ctx, "Google OAuth client ID is not set", map[string]interface{}{"env_var": "GOOGLE_OAUTH_CLIENT_ID"})
		}
		if redirectURL == "" {
			s.logger.Warn(ctx, "Google OAuth redirect URL is not set", map[string]interface{}{"env_var": "GOOGLE_OAUTH_REDIRECT_URL"})
		}
	}

	query := url.Values{
		"client_id":     {clientID},
		"redirect_uri":  {redirectURL},
		"response_type": {"code"},
		"scope":         {"openid email profile"},
		"state":         {state},
		"access_type":   {"offline"},
		"prompt":        {"consent"},
	}
	return "https://accounts.google.com/o/oauth2/v2/auth?" + query.Encode()
}
// ExchangeCodeForToken trades an OAuth authorization code for Google tokens
// via the authorization_code grant.
//
// The form-encoded request is POSTed to s.TokenEndpoint, or to Google's
// production token endpoint when that field is empty. On a non-200 response
// the body is parsed as a standard OAuth error object and mapped onto this
// package's sentinel errors:
//   - "invalid_grant" -> ErrOAuthCodeAlreadyUsed
//   - "invalid_client" -> ErrOAuthClientConfig
//   - "invalid_request" -> ErrOAuthInvalidRequest
//   - "unauthorized_client" -> ErrOAuthUnauthorized
//   - "unsupported_grant_type" -> ErrOAuthUnsupportedGrant
//
// Any other error code, or an unparseable body, yields
// contextutils.ErrOAuthProviderError.
//
// The authorization code is a short-lived credential, so the tracing span
// records only whether one was supplied, never its value.
func (s *OAuthService) ExchangeCodeForToken(ctx context.Context, code string) (result0 *GoogleTokenResponse, err error) {
	ctx, span := observability.TraceOAuthFunction(ctx, "exchange_code_for_token",
		attribute.Bool("oauth.code_present", code != ""),
		attribute.String("oauth.token_endpoint", s.TokenEndpoint),
	)
	defer observability.FinishSpan(span, &err)

	// Upper bound on how much of an error response is read. OAuth error
	// objects are tiny; this only guards against a misbehaving endpoint.
	const maxErrorBodyBytes = 64 << 10

	data := url.Values{}
	data.Set("client_id", s.config.GoogleOAuthClientID)
	data.Set("client_secret", s.config.GoogleOAuthClientSecret)
	data.Set("code", code)
	data.Set("grant_type", "authorization_code")
	data.Set("redirect_uri", s.config.GoogleOAuthRedirectURL)

	tokenURL := s.TokenEndpoint
	if tokenURL == "" {
		tokenURL = "https://oauth2.googleapis.com/token"
	}

	// Bind ctx at construction so cancellation and trace context reach the call.
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenURL, strings.NewReader(data.Encode()))
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to create token request")
	}
	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")

	// Use instrumented HTTP client for automatic tracing with explicit span options
	client := &http.Client{
		Timeout: config.OAuthHTTPTimeout,
		Transport: otelhttp.NewTransport(http.DefaultTransport,
			otelhttp.WithSpanOptions(trace.WithSpanKind(trace.SpanKindClient)),
		),
	}
	resp, err := client.Do(req)
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to exchange code for token")
	}
	defer func() {
		// The logger is optional (GetGoogleAuthURL also nil-checks it).
		if cerr := resp.Body.Close(); cerr != nil && s.logger != nil {
			s.logger.Warn(ctx, "Failed to close response body", map[string]interface{}{"error": cerr.Error()})
		}
	}()

	span.SetAttributes(attribute.Int("http.status_code", resp.StatusCode))
	if resp.StatusCode != http.StatusOK {
		// Best-effort read: even a partial body is useful in the error below.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))

		// Try to parse the error response for better error messages
		var errorResp struct {
			Error            string `json:"error"`
			ErrorDescription string `json:"error_description"`
		}
		if json.Unmarshal(body, &errorResp) == nil {
			span.SetAttributes(
				attribute.String("oauth.error", errorResp.Error),
				attribute.String("oauth.error_description", errorResp.ErrorDescription),
			)
			switch errorResp.Error {
			case "invalid_grant":
				return nil, contextutils.WrapErrorf(ErrOAuthCodeAlreadyUsed, "please try signing in again")
			case "invalid_client":
				return nil, contextutils.WrapError(ErrOAuthClientConfig, "")
			case "invalid_request":
				return nil, contextutils.WrapError(ErrOAuthInvalidRequest, "")
			case "unauthorized_client":
				return nil, contextutils.WrapError(ErrOAuthUnauthorized, "")
			case "unsupported_grant_type":
				return nil, contextutils.WrapError(ErrOAuthUnsupportedGrant, "")
			default:
				return nil, contextutils.WrapErrorf(contextutils.ErrOAuthProviderError, "OAuth error: %s - %s", errorResp.Error, errorResp.ErrorDescription)
			}
		}
		return nil, contextutils.WrapErrorf(contextutils.ErrOAuthProviderError, "token exchange failed with status %d: %s", resp.StatusCode, string(body))
	}

	var tokenResp GoogleTokenResponse
	if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to decode token response")
	}
	span.SetAttributes(
		attribute.String("oauth.token_type", tokenResp.TokenType),
		attribute.Int("oauth.expires_in", tokenResp.ExpiresIn),
	)
	return &tokenResp, nil
}
// GetGoogleUserInfo fetches the Google profile for accessToken with a GET to
// s.UserInfoEndpoint (Google's v2 userinfo URL when that field is empty),
// sending the token as a Bearer Authorization header.
//
// A non-200 response is returned as a wrapped
// contextutils.ErrOAuthProviderError that includes the status code and (a
// bounded prefix of) the response body; decode failures are wrapped as well.
func (s *OAuthService) GetGoogleUserInfo(ctx context.Context, accessToken string) (result0 *GoogleUserInfo, err error) {
	// maxErrorBodyBytes caps how much of a failed response is buffered in
	// memory and echoed into the span and the returned error.
	const maxErrorBodyBytes = 64 << 10

	ctx, span := observability.TraceOAuthFunction(ctx, "get_google_user_info",
		attribute.String("oauth.userinfo_endpoint", s.UserInfoEndpoint),
	)
	defer observability.FinishSpan(span, &err)

	userinfoURL := s.UserInfoEndpoint
	if userinfoURL == "" {
		userinfoURL = "https://www.googleapis.com/oauth2/v2/userinfo"
	}

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, userinfoURL, nil)
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to create userinfo request")
	}
	req.Header.Set("Authorization", "Bearer "+accessToken)

	// Use instrumented HTTP client for automatic tracing with explicit span options
	client := &http.Client{
		Timeout: config.OAuthHTTPTimeout,
		Transport: otelhttp.NewTransport(http.DefaultTransport,
			otelhttp.WithSpanOptions(trace.WithSpanKind(trace.SpanKindClient)),
		),
	}
	resp, err := client.Do(req)
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to get user info")
	}
	defer func() {
		cerr := resp.Body.Close()
		if cerr != nil {
			s.logger.Warn(ctx, "Failed to close response body", map[string]interface{}{"error": cerr.Error()})
		}
	}()

	span.SetAttributes(attribute.Int("http.status_code", resp.StatusCode))
	if resp.StatusCode != http.StatusOK {
		// Best effort: the body only enriches the error message.
		body, _ := io.ReadAll(io.LimitReader(resp.Body, maxErrorBodyBytes))
		span.SetAttributes(attribute.String("error", fmt.Sprintf("userinfo request failed with status %d: %s", resp.StatusCode, string(body))))
		return nil, contextutils.WrapErrorf(contextutils.ErrOAuthProviderError, "userinfo request failed with status %d: %s", resp.StatusCode, string(body))
	}

	var userInfo GoogleUserInfo
	if err := json.NewDecoder(resp.Body).Decode(&userInfo); err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to decode user info")
	}
	span.SetAttributes(
		attribute.String("user.email", userInfo.Email),
		attribute.String("user.id", userInfo.ID),
		attribute.Bool("user.verified_email", userInfo.VerifiedEmail),
	)
	return &userInfo, nil
}
// AuthenticateGoogleUser runs the complete Google OAuth sign-in flow for an
// authorization code. It exchanges the code for a token, fetches the Google
// profile, and returns the local user with the same email.
//
// When no such user exists, it creates one unless signups are disabled and
// s.config.IsOAuthSignupAllowed rejects the email. The new user's username is
// the email; if that is taken, "_1", "_2", … is appended until
// GetUserByUsername finds no clash.
//
// Errors:
//   - ErrSignupsDisabled when signups are disabled for this email.
//   - a wrapped contextutils.ErrOAuthProviderError when the profile has no email.
//   - an error when no free username is found within maxUsernameAttempts suffixes.
//   - wrapped errors from the token exchange, the userinfo call, or userService.
func (s *OAuthService) AuthenticateGoogleUser(ctx context.Context, code string, userService UserServiceInterface) (result0 *models.User, err error) {
	// maxUsernameAttempts bounds the "<email>_N" suffix search so a
	// misbehaving user store cannot spin the loop below forever.
	const maxUsernameAttempts = 100

	// The authorization code is a credential: record only whether one was
	// supplied, never its value.
	ctx, span := observability.TraceOAuthFunction(ctx, "authenticate_google_user",
		attribute.Bool("oauth.code_present", code != ""),
	)
	defer observability.FinishSpan(span, &err)

	// Exchange code for token
	tokenResp, err := s.ExchangeCodeForToken(ctx, code)
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to exchange code for token")
	}

	// Get user info from Google
	userInfo, err := s.GetGoogleUserInfo(ctx, tokenResp.AccessToken)
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to get user info")
	}
	span.SetAttributes(
		attribute.String("user.email", userInfo.Email),
		attribute.String("user.id", userInfo.ID),
	)

	// Accounts are matched and named by email; without one we would look up
	// (and possibly create) a user with an empty email and username.
	if userInfo.Email == "" {
		err = contextutils.WrapErrorf(contextutils.ErrOAuthProviderError, "google user info did not include an email address")
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, err
	}

	// NOTE(review): existing accounts are linked purely by email and
	// userInfo.VerifiedEmail is not checked; confirm whether unverified Google
	// emails should be rejected before matching or creating accounts.
	existingUser, err := userService.GetUserByEmail(ctx, userInfo.Email)
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to check existing user")
	}
	if existingUser != nil {
		span.SetAttributes(
			attribute.Int("user.id", existingUser.ID),
			attribute.String("auth.result", "existing_user"),
		)
		return existingUser, nil
	}

	// Check if signups are disabled before creating new user
	if s.config != nil && s.config.IsSignupDisabled() {
		// Check if OAuth signup is allowed via whitelist
		if !s.config.IsOAuthSignupAllowed(userInfo.Email) {
			span.SetAttributes(
				attribute.String("auth.result", "oauth_signup_blocked"),
				attribute.String("user.email", userInfo.Email),
			)
			return nil, ErrSignupsDisabled
		}
		// Allow OAuth signup for whitelisted email/domain
		span.SetAttributes(
			attribute.String("auth.result", "oauth_signup_allowed"),
			attribute.String("user.email", userInfo.Email),
		)
	}

	// User doesn't exist: pick a username, starting from the email and
	// appending "_N" suffixes while the candidate is already taken.
	username := userInfo.Email
	email := userInfo.Email
	for attempt := 1; ; attempt++ {
		taken, lookupErr := userService.GetUserByUsername(ctx, username)
		if lookupErr != nil {
			span.SetAttributes(attribute.String("error", lookupErr.Error()))
			return nil, contextutils.WrapError(lookupErr, "failed to check username availability")
		}
		if taken == nil {
			break
		}
		if attempt > maxUsernameAttempts {
			err = fmt.Errorf("no available username for %s after %d attempts", email, maxUsernameAttempts)
			span.SetAttributes(attribute.String("error", err.Error()))
			return nil, err
		}
		username = fmt.Sprintf("%s_%d", userInfo.Email, attempt)
	}
	span.SetAttributes(
		attribute.String("user.username", username),
		attribute.String("user.email", email),
		attribute.String("auth.result", "new_user"),
	)

	// Create user with default settings (UTC timezone, italian, beginner).
	user, err := userService.CreateUserWithEmailAndTimezone(ctx, username, email, "UTC", "italian", "beginner")
	if err != nil {
		span.SetAttributes(attribute.String("error", err.Error()))
		return nil, contextutils.WrapError(err, "failed to create user")
	}
	span.SetAttributes(attribute.Int("user.id", user.ID))
	return user, nil
}
package services
import (
"context"
"database/sql"
"errors"
"fmt"
"math/rand"
"strconv"
"strings"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
)
// QuestionServiceInterface lists the question-related operations that callers
// depend on, so that a mock can stand in for the real service in tests.
type QuestionServiceInterface interface {
	// Creating questions and managing which users they are assigned to.
	SaveQuestion(ctx context.Context, question *models.Question) error
	AssignQuestionToUser(ctx context.Context, questionID, userID int) error
	AssignUsersToQuestion(ctx context.Context, questionID int, userIDs []int) error
	UnassignUsersFromQuestion(ctx context.Context, questionID int, userIDs []int) error
	GetUsersForQuestion(ctx context.Context, questionID int) ([]*models.User, int, error)

	// Fetching single questions or selections of questions for a user.
	GetQuestionByID(ctx context.Context, id int) (*models.Question, error)
	GetQuestionWithStats(ctx context.Context, id int) (*QuestionWithStats, error)
	GetQuestionsByFilter(ctx context.Context, userID int, language, level string, questionType models.QuestionType, limit int) ([]models.Question, error)
	GetNextQuestion(ctx context.Context, userID int, language, level string, qType models.QuestionType) (*QuestionWithStats, error)
	GetAdaptiveQuestionsForDaily(ctx context.Context, userID int, language, level string, limit int) ([]*QuestionWithStats, error)
	GetRandomGlobalQuestionForUser(ctx context.Context, userID int, language, level string, qType models.QuestionType) (*QuestionWithStats, error)
	GetRecentQuestionContentsForUser(ctx context.Context, userID, limit int) ([]string, error)
	GetUserQuestions(ctx context.Context, userID, limit int) ([]*models.Question, error)
	GetUserQuestionsWithStats(ctx context.Context, userID, limit int) ([]*QuestionWithStats, error)

	// Paginated listings.
	GetQuestionsPaginated(ctx context.Context, userID, page, pageSize int, search, typeFilter, statusFilter string) ([]*QuestionWithStats, int, error)
	GetAllQuestionsPaginated(ctx context.Context, page, pageSize int, search, typeFilter, statusFilter, languageFilter, levelFilter string, userID *int) ([]*QuestionWithStats, int, error)
	GetReportedQuestionsPaginated(ctx context.Context, page, pageSize int, search, typeFilter, languageFilter, levelFilter string) ([]*QuestionWithStats, int, error)

	// Reporting, moderation, and editing.
	ReportQuestion(ctx context.Context, questionID, userID int, reportReason string) error
	GetReportedQuestions(ctx context.Context) ([]*ReportedQuestionWithUser, error)
	GetReportedQuestionsStats(ctx context.Context) (map[string]interface{}, error)
	MarkQuestionAsFixed(ctx context.Context, questionID int) error
	UpdateQuestion(ctx context.Context, questionID int, content map[string]interface{}, correctAnswerIndex int, explanation string) error
	DeleteQuestion(ctx context.Context, questionID int) error

	// Aggregate statistics and counts.
	GetQuestionStats(ctx context.Context) (map[string]interface{}, error)
	GetDetailedQuestionStats(ctx context.Context) (map[string]interface{}, error)
	GetUserQuestionCount(ctx context.Context, userID int) (int, error)
	GetUserResponseCount(ctx context.Context, userID int) (int, error)

	// DB exposes the underlying database handle.
	DB() *sql.DB
}
// QuestionService provides methods for question management.
// Instances are built by NewQuestionServiceWithLogger, which panics when db or
// logger is nil.
type QuestionService struct {
	// db is the database handle every query in this service runs against.
	db              *sql.DB
	// learningService is stored as given by the constructor (not nil-checked).
	// NOTE(review): not referenced by the methods shown alongside this type.
	learningService *LearningService
	// logger receives warnings such as failures to close result sets.
	logger          *observability.Logger
	// cfg supplies optional settings such as Server.DailyRepeatAvoidDays; it
	// may be nil, in which case getDailyRepeatAvoidDays uses its default.
	cfg             *config.Config
}
// Shared query constants used by several queries in this service.
const (
	// questionSelectFields is the full question column list, in the exact order
	// scanQuestionFromRow expects (GetQuestionByID selects it). Keep the two in
	// sync when adding columns.
	questionSelectFields = "id, type, language, level, difficulty_score, " +
		"content, correct_answer, explanation, created_at, status, " +
		"topic_category, grammar_focus, vocabulary_domain, scenario, " +
		"style_modifier, difficulty_modifier, time_context"
)
// scanQuestionFromRow reads one question from a single-row result whose
// columns follow questionSelectFields order. The nullable metadata columns
// (topic_category through time_context) become "" when NULL, and the content
// column is decoded with UnmarshalContentFromJSON.
//
// Scan and decode errors are returned as-is, so a missing row surfaces as
// sql.ErrNoRows.
func (s *QuestionService) scanQuestionFromRow(row *sql.Row) (result0 *models.Question, err error) {
	var (
		q       models.Question
		rawJSON string
		topic   sql.NullString
		grammar sql.NullString
		vocab   sql.NullString
		scen    sql.NullString
		style   sql.NullString
		diffMod sql.NullString
		timeCtx sql.NullString
	)
	if err = row.Scan(
		&q.ID, &q.Type, &q.Language, &q.Level, &q.DifficultyScore,
		&rawJSON, &q.CorrectAnswer, &q.Explanation, &q.CreatedAt, &q.Status,
		&topic, &grammar, &vocab, &scen, &style, &diffMod, &timeCtx,
	); err != nil {
		return nil, err
	}

	// NullString.String is "" for a NULL column, and the fields start empty,
	// so an unconditional copy matches a Valid-guarded one.
	q.TopicCategory = topic.String
	q.GrammarFocus = grammar.String
	q.VocabularyDomain = vocab.String
	q.Scenario = scen.String
	q.StyleModifier = style.String
	q.DifficultyModifier = diffMod.String
	q.TimeContext = timeCtx.String

	if err = q.UnmarshalContentFromJSON(rawJSON); err != nil {
		return nil, err
	}
	return &q, nil
}
// scanQuestionFromRows reads the current row of rows into a new
// models.Question. The expected column order is questionSelectFields: ten core
// columns followed by seven nullable metadata columns (topic_category through
// time_context), which are left as "" when NULL.
//
// Scan and content-decoding errors are returned unwrapped.
func (s *QuestionService) scanQuestionFromRows(rows *sql.Rows) (result0 *models.Question, err error) {
	question := new(models.Question)
	var contentJSON string
	// Nullable metadata columns, in scan order.
	var opt [7]sql.NullString

	err = rows.Scan(
		&question.ID,
		&question.Type,
		&question.Language,
		&question.Level,
		&question.DifficultyScore,
		&contentJSON,
		&question.CorrectAnswer,
		&question.Explanation,
		&question.CreatedAt,
		&question.Status,
		&opt[0], &opt[1], &opt[2], &opt[3], &opt[4], &opt[5], &opt[6],
	)
	if err != nil {
		return nil, err
	}

	// Destinations for opt, in the same order; a NULL column leaves "".
	targets := [...]*string{
		&question.TopicCategory,
		&question.GrammarFocus,
		&question.VocabularyDomain,
		&question.Scenario,
		&question.StyleModifier,
		&question.DifficultyModifier,
		&question.TimeContext,
	}
	for i, dst := range targets {
		*dst = opt[i].String
	}

	if err = question.UnmarshalContentFromJSON(contentJSON); err != nil {
		return nil, err
	}
	return question, nil
}
// scanQuestionBasicFromRows reads the current row into a models.Question using
// only the ten core columns (id through status). The metadata fields such as
// TopicCategory are not part of this row shape and stay empty. Errors from
// Scan or from decoding the content JSON are returned unwrapped.
func (s *QuestionService) scanQuestionBasicFromRows(rows *sql.Rows) (result0 *models.Question, err error) {
	var (
		q   models.Question
		raw string
	)
	if err = rows.Scan(
		&q.ID, &q.Type, &q.Language, &q.Level, &q.DifficultyScore,
		&raw, &q.CorrectAnswer, &q.Explanation, &q.CreatedAt, &q.Status,
	); err != nil {
		return nil, err
	}
	if err = q.UnmarshalContentFromJSON(raw); err != nil {
		return nil, err
	}
	return &q, nil
}
// scanQuestionWithStatsFromRows reads the current row into a QuestionWithStats.
// The expected column order is:
//   - the ten core question columns (id through status);
//   - the correct, incorrect, and total response counts;
//   - the user count.
//
// Metadata fields such as TopicCategory are not part of this row shape and
// stay empty. Errors are returned unwrapped.
func (s *QuestionService) scanQuestionWithStatsFromRows(rows *sql.Rows) (result0 *QuestionWithStats, err error) {
	out := &QuestionWithStats{Question: &models.Question{}}
	var rawContent string

	targets := []interface{}{
		// Core question columns.
		&out.ID, &out.Type, &out.Language, &out.Level, &out.DifficultyScore,
		&rawContent, &out.CorrectAnswer, &out.Explanation, &out.CreatedAt, &out.Status,
		// Aggregates.
		&out.CorrectCount, &out.IncorrectCount, &out.TotalResponses, &out.UserCount,
	}
	if err = rows.Scan(targets...); err != nil {
		return nil, err
	}
	if err = out.UnmarshalContentFromJSON(rawContent); err != nil {
		return nil, err
	}
	return out, nil
}
// scanQuestionWithStatsAndAllFieldsFromRows reads the current row into a
// QuestionWithStats. The expected column order is:
//   - the 17 question columns (questionSelectFields order);
//   - the correct, incorrect, and total response counts;
//   - the user count.
//
// NULL metadata columns leave the corresponding field "". Errors are returned
// unwrapped.
func (s *QuestionService) scanQuestionWithStatsAndAllFieldsFromRows(rows *sql.Rows) (result0 *QuestionWithStats, err error) {
	out := &QuestionWithStats{Question: &models.Question{}}

	var contentJSON string
	var topic, grammar, vocab, scenario sql.NullString
	var styleMod, diffMod, timeCtx sql.NullString

	dest := []interface{}{
		// Core question columns.
		&out.ID, &out.Type, &out.Language, &out.Level, &out.DifficultyScore,
		&contentJSON, &out.CorrectAnswer, &out.Explanation, &out.CreatedAt, &out.Status,
		// Nullable metadata columns.
		&topic, &grammar, &vocab, &scenario, &styleMod, &diffMod, &timeCtx,
		// Aggregates.
		&out.CorrectCount, &out.IncorrectCount, &out.TotalResponses, &out.UserCount,
	}
	if err = rows.Scan(dest...); err != nil {
		return nil, err
	}

	// A NULL column yields NullString.String == "", the field's starting value.
	out.TopicCategory = topic.String
	out.GrammarFocus = grammar.String
	out.VocabularyDomain = vocab.String
	out.Scenario = scenario.String
	out.StyleModifier = styleMod.String
	out.DifficultyModifier = diffMod.String
	out.TimeContext = timeCtx.String

	if err = out.UnmarshalContentFromJSON(contentJSON); err != nil {
		return nil, err
	}
	return out, nil
}
// scanQuestionWithPriorityAndStatsFromRows scans a database rows into a QuestionWithStats struct (with priority and stats)
func (s *QuestionService) scanQuestionWithPriorityAndStatsFromRows(rows *sql.Rows) (result0 *QuestionWithStats, err error) {
questionWithStats := &QuestionWithStats{
Question: &models.Question{},
}
var contentJSON string
var priorityScore float64
var timesAnswered int
var lastAnsweredAt sql.NullTime
var confidenceLevel sql.NullInt32
var topicCategory sql.NullString
var grammarFocus sql.NullString
var vocabularyDomain sql.NullString
var scenario sql.NullString
var styleModifier sql.NullString
var difficultyModifier sql.NullString
var timeContext sql.NullString
err = rows.Scan(
&questionWithStats.ID,
&questionWithStats.Type,
&questionWithStats.Language,
&questionWithStats.Level,
&questionWithStats.DifficultyScore,
&contentJSON,
&questionWithStats.CorrectAnswer,
&questionWithStats.Explanation,
&questionWithStats.CreatedAt,
&questionWithStats.Status,
&topicCategory,
&grammarFocus,
&vocabularyDomain,
&scenario,
&styleModifier,
&difficultyModifier,
&timeContext,
&priorityScore,
×Answered,
&lastAnsweredAt,
&questionWithStats.CorrectCount,
&questionWithStats.IncorrectCount,
&questionWithStats.TotalResponses,
&confidenceLevel,
)
if err != nil {
return nil, err
}
// Set optional string fields if they have values
if topicCategory.Valid {
questionWithStats.TopicCategory = topicCategory.String
}
if grammarFocus.Valid {
questionWithStats.GrammarFocus = grammarFocus.String
}
if vocabularyDomain.Valid {
questionWithStats.VocabularyDomain = vocabularyDomain.String
}
if scenario.Valid {
questionWithStats.Scenario = scenario.String
}
if styleModifier.Valid {
questionWithStats.StyleModifier = styleModifier.String
}
if difficultyModifier.Valid {
questionWithStats.DifficultyModifier = difficultyModifier.String
}
if timeContext.Valid {
questionWithStats.TimeContext = timeContext.String
}
if err := questionWithStats.UnmarshalContentFromJSON(contentJSON); err != nil {
return nil, err
}
// Set confidence level if it exists
if confidenceLevel.Valid {
level := int(confidenceLevel.Int32)
questionWithStats.ConfidenceLevel = &level
}
// Populate per-user times answered from the scanned value
questionWithStats.TimesAnswered = timesAnswered
return questionWithStats, nil
}
// scanQuestionWithStatsAndReportersFromRows reads the current row of a
// reported-questions query into a QuestionWithStats. The expected column order
// is:
//   - the 17 question columns (questionSelectFields order);
//   - the correct, incorrect, and total response counts;
//   - the reporters column;
//   - the report-reasons column.
//
// NULL metadata, reporter, and reason columns all leave the corresponding
// field "". Errors are returned unwrapped.
func (s *QuestionService) scanQuestionWithStatsAndReportersFromRows(rows *sql.Rows) (result0 *QuestionWithStats, err error) {
	qws := &QuestionWithStats{Question: &models.Question{}}

	var contentJSON string
	var meta [7]sql.NullString // topic_category through time_context
	var reporters, reasons sql.NullString

	err = rows.Scan(
		&qws.ID, &qws.Type, &qws.Language, &qws.Level, &qws.DifficultyScore,
		&contentJSON, &qws.CorrectAnswer, &qws.Explanation, &qws.CreatedAt, &qws.Status,
		&meta[0], &meta[1], &meta[2], &meta[3], &meta[4], &meta[5], &meta[6],
		&qws.CorrectCount, &qws.IncorrectCount, &qws.TotalResponses,
		&reporters, &reasons,
	)
	if err != nil {
		return nil, err
	}

	qws.TopicCategory = meta[0].String
	qws.GrammarFocus = meta[1].String
	qws.VocabularyDomain = meta[2].String
	qws.Scenario = meta[3].String
	qws.StyleModifier = meta[4].String
	qws.DifficultyModifier = meta[5].String
	qws.TimeContext = meta[6].String

	if err = qws.UnmarshalContentFromJSON(contentJSON); err != nil {
		return nil, err
	}

	// Reporters and ReportReasons start empty, so copying "" for a NULL or
	// empty column is the same as leaving them untouched.
	qws.Reporters = reporters.String
	qws.ReportReasons = reasons.String
	return qws, nil
}
// getQuestionByQuery runs a single-row query and scans the result with
// scanQuestionFromRow. The query must select the questionSelectFields columns
// in order. When no row matches, the bare sql.ErrNoRows sentinel is returned
// so callers can compare against it directly; any other error is passed
// through unchanged.
func (s *QuestionService) getQuestionByQuery(ctx context.Context, query string, args ...interface{}) (result0 *models.Question, err error) {
	q, scanErr := s.scanQuestionFromRow(s.db.QueryRowContext(ctx, query, args...))
	switch {
	case scanErr == nil:
		return q, nil
	case errors.Is(scanErr, sql.ErrNoRows):
		return nil, sql.ErrNoRows
	default:
		return nil, scanErr
	}
}
// NewQuestionServiceWithLogger builds a QuestionService from its dependencies.
// It panics if db or logger is nil; db is checked first. learningService and
// cfg are stored as given without nil checks.
func NewQuestionServiceWithLogger(db *sql.DB, learningService *LearningService, cfg *config.Config, logger *observability.Logger) *QuestionService {
	switch {
	case db == nil:
		panic("database connection cannot be nil")
	case logger == nil:
		panic("logger cannot be nil")
	}

	svc := &QuestionService{
		db:              db,
		logger:          logger,
		cfg:             cfg,
		learningService: learningService,
	}
	return svc
}
// getDailyRepeatAvoidDays reports how many days a question is kept from
// repeating in daily assignments. A positive cfg.Server.DailyRepeatAvoidDays
// is used as-is. A nil config or a zero or negative value falls back to 7.
func (s *QuestionService) getDailyRepeatAvoidDays() int {
	const fallbackDays = 7
	if s.cfg == nil || s.cfg.Server.DailyRepeatAvoidDays <= 0 {
		return fallbackDays
	}
	return s.cfg.Server.DailyRepeatAvoidDays
}
// SaveQuestion inserts question into the questions table. On success it stores
// the generated ID in question.ID.
//
// An empty question.Status is set to models.QuestionStatusActive on the
// caller's struct before the insert. The question content is serialized with
// MarshalContentToJSON. Marshal and insert errors are wrapped with
// contextutils.WrapError.
func (s *QuestionService) SaveQuestion(ctx context.Context, question *models.Question) (err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "save_question", observability.AttributeQuestion(question))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	contentJSON, err := question.MarshalContentToJSON()
	if err != nil {
		return contextutils.WrapError(err, "failed to marshal question content")
	}

	if question.Status == "" {
		question.Status = models.QuestionStatusActive
	}

	query := `
		INSERT INTO questions (type, language, level, difficulty_score, content, correct_answer, explanation, status, topic_category, grammar_focus, vocabulary_domain, scenario, style_modifier, difficulty_modifier, time_context)
		VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15) RETURNING id
	`
	// Scan into a local so question.ID is only touched once the insert succeeded.
	var id int
	err = s.db.QueryRowContext(ctx, query,
		question.Type,
		question.Language,
		question.Level,
		question.DifficultyScore,
		contentJSON,
		question.CorrectAnswer,
		question.Explanation,
		question.Status,
		question.TopicCategory,
		question.GrammarFocus,
		question.VocabularyDomain,
		question.Scenario,
		question.StyleModifier,
		question.DifficultyModifier,
		question.TimeContext,
	).Scan(&id)
	if err != nil {
		return contextutils.WrapError(err, "failed to save question to database")
	}
	question.ID = id
	return nil
}
// AssignQuestionToUser records in user_questions that userID is assigned
// questionID. An existing assignment is left unchanged (ON CONFLICT DO
// NOTHING), so repeating the call succeeds without duplicating the row.
// Database errors are wrapped with contextutils.WrapError.
func (s *QuestionService) AssignQuestionToUser(ctx context.Context, questionID, userID int) (err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "assign_question_to_user", observability.AttributeQuestionID(questionID), observability.AttributeUserID(userID))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
		INSERT INTO user_questions (user_id, question_id)
		VALUES ($1, $2)
		ON CONFLICT (user_id, question_id) DO NOTHING
	`
	// Only wrap a real failure: the success path returns an explicit nil rather
	// than relying on how WrapError treats a nil error.
	if _, err = s.db.ExecContext(ctx, query, userID, questionID); err != nil {
		return contextutils.WrapError(err, "failed to assign question to user")
	}
	return nil
}
// GetQuestionByID loads the question with the given ID, selecting the
// questionSelectFields columns. When no row matches, the error is
// sql.ErrNoRows, as passed through by getQuestionByQuery.
func (s *QuestionService) GetQuestionByID(ctx context.Context, id int) (result0 *models.Question, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_question_by_id", observability.AttributeQuestionID(id))
	defer func() {
		if err == nil {
			span.End()
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
		span.End()
	}()

	const byIDQuery = "SELECT " + questionSelectFields + " FROM questions WHERE id = $1"
	return s.getQuestionByQuery(ctx, byIDQuery, id)
}
// GetQuestionWithStats retrieves a question by ID together with aggregate
// response counts from user_responses (correct, incorrect, total). It returns
// contextutils.ErrQuestionNotFound when no question has that ID.
//
// The optional metadata columns (topic_category through time_context) may be
// NULL. They are scanned through sql.NullString and left as "" when NULL,
// because database/sql refuses to scan NULL directly into a string field.
func (s *QuestionService) GetQuestionWithStats(ctx context.Context, id int) (result0 *QuestionWithStats, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_question_with_stats", observability.AttributeQuestionID(id))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
		SELECT
			q.id, q.type, q.language, q.level, q.difficulty_score,
			q.content, q.correct_answer, q.explanation, q.created_at, q.status,
			q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context,
			COALESCE(SUM(CASE WHEN ur.is_correct = true THEN 1 ELSE 0 END), 0) as correct_count,
			COALESCE(SUM(CASE WHEN ur.is_correct = false THEN 1 ELSE 0 END), 0) as incorrect_count,
			COALESCE(COUNT(ur.id), 0) as total_responses
		FROM questions q
		LEFT JOIN user_responses ur ON q.id = ur.question_id
		WHERE q.id = $1
		GROUP BY q.id, q.type, q.language, q.level, q.difficulty_score,
			q.content, q.correct_answer, q.explanation, q.created_at, q.status,
			q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context
	`
	q := &models.Question{}
	stats := &QuestionWithStats{Question: q}
	var contentJSON string
	var topicCategory, grammarFocus, vocabularyDomain, scenario sql.NullString
	var styleModifier, difficultyModifier, timeContext sql.NullString
	err = s.db.QueryRowContext(ctx, query, id).Scan(
		&q.ID, &q.Type, &q.Language, &q.Level, &q.DifficultyScore,
		&contentJSON, &q.CorrectAnswer, &q.Explanation, &q.CreatedAt, &q.Status,
		&topicCategory, &grammarFocus, &vocabularyDomain, &scenario, &styleModifier, &difficultyModifier, &timeContext,
		&stats.CorrectCount, &stats.IncorrectCount, &stats.TotalResponses,
	)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, contextutils.ErrQuestionNotFound
		}
		return nil, contextutils.WrapError(err, "failed to get question with stats")
	}

	// NullString.String is "" for NULL columns.
	q.TopicCategory = topicCategory.String
	q.GrammarFocus = grammarFocus.String
	q.VocabularyDomain = vocabularyDomain.String
	q.Scenario = scenario.String
	q.StyleModifier = styleModifier.String
	q.DifficultyModifier = difficultyModifier.String
	q.TimeContext = timeContext.String

	// Parse JSON content
	if err := q.UnmarshalContentFromJSON(contentJSON); err != nil {
		return nil, contextutils.WrapError(err, "failed to unmarshal question content")
	}
	return stats, nil
}
// GetQuestionsByFilter returns up to limit active questions assigned to userID
// (via user_questions) in the given language and level, in random order. An
// empty questionType matches every type.
//
// Only the ten core columns are loaded, so metadata fields such as
// TopicCategory are left empty. Query, scan, and iteration errors are wrapped.
func (s *QuestionService) GetQuestionsByFilter(ctx context.Context, userID int, language, level string, questionType models.QuestionType, limit int) (result0 []models.Question, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_questions_by_filter", observability.AttributeUserID(userID), observability.AttributeLanguage(language), observability.AttributeLevel(level), observability.AttributeQuestionType(questionType))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	var query string
	var args []interface{}
	if questionType == "" {
		// Don't filter by type if questionType is empty
		query = `
			SELECT q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status
			FROM questions q
			JOIN user_questions uq ON q.id = uq.question_id
			WHERE uq.user_id = $1 AND q.language = $2 AND q.level = $3 AND q.status = $4
			ORDER BY RANDOM()
			LIMIT $5
		`
		args = []interface{}{userID, language, level, models.QuestionStatusActive, limit}
	} else {
		// Filter by specific type
		query = `
			SELECT q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status
			FROM questions q
			JOIN user_questions uq ON q.id = uq.question_id
			WHERE uq.user_id = $1 AND q.language = $2 AND q.level = $3 AND q.type = $4 AND q.status = $5
			ORDER BY RANDOM()
			LIMIT $6
		`
		args = []interface{}{userID, language, level, questionType, models.QuestionStatusActive, limit}
	}

	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to query questions by filter")
	}
	defer func() {
		if err := rows.Close(); err != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": err.Error()})
		}
	}()

	var questions []models.Question
	for rows.Next() {
		question, scanErr := s.scanQuestionBasicFromRows(rows)
		if scanErr != nil {
			return nil, contextutils.WrapError(scanErr, "failed to scan question from rows")
		}
		questions = append(questions, *question)
	}
	// rows.Next returns false both when the result set is exhausted and when
	// iteration fails; only rows.Err distinguishes the two.
	if err = rows.Err(); err != nil {
		return nil, contextutils.WrapError(err, "failed to iterate questions by filter")
	}
	return questions, nil
}
// ReportedQuestionWithUser represents a reported question with user information.
// It embeds *models.Question, so the question's fields are promoted onto it.
type ReportedQuestionWithUser struct {
	*models.Question
	// ReportedByUsername is filled by GetReportedQuestions from the user joined
	// through user_questions; it is "" when no user row matched.
	ReportedByUsername string `json:"reported_by_username"`
	// TotalResponses is the COUNT of user_responses rows joined to the question.
	TotalResponses     int    `json:"total_responses"`
}
// GetReportedQuestions returns every question whose status is
// models.QuestionStatusReported, newest first, with its response count and an
// associated username. Query, scan, decode, and iteration errors are wrapped.
//
// NOTE(review): the username comes from users joined through user_questions,
// the assignment table that AssignQuestionToUser writes, not from a report
// record. It is therefore an assigned user rather than necessarily the
// reporter, and a question assigned to several users yields one entry per
// username. Confirm this is intended.
func (s *QuestionService) GetReportedQuestions(ctx context.Context) (result0 []*ReportedQuestionWithUser, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_reported_questions")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	query := `
		SELECT q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status, u.username,
		COALESCE(COUNT(ur.id), 0) as total_responses
		FROM questions q
		LEFT JOIN user_questions uq ON q.id = uq.question_id
		LEFT JOIN users u ON uq.user_id = u.id
		LEFT JOIN user_responses ur ON q.id = ur.question_id
		WHERE q.status = $1
		GROUP BY q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status, u.username
		ORDER BY q.created_at DESC
	`
	var rows *sql.Rows
	rows, err = s.db.QueryContext(ctx, query, models.QuestionStatusReported)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to query reported questions")
	}
	defer func() {
		if err := rows.Close(); err != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": err.Error()})
		}
	}()

	var questions []*ReportedQuestionWithUser
	for rows.Next() {
		// Declared per iteration so each appended entry points at its own Question.
		var question models.Question
		var reportedByUsername sql.NullString
		var contentJSON string
		var totalResponses int
		if err = rows.Scan(
			&question.ID,
			&question.Type,
			&question.Language,
			&question.Level,
			&question.DifficultyScore,
			&contentJSON,
			&question.CorrectAnswer,
			&question.Explanation,
			&question.CreatedAt,
			&question.Status,
			&reportedByUsername,
			&totalResponses,
		); err != nil {
			return nil, contextutils.WrapError(err, "failed to scan reported question")
		}
		if err = question.UnmarshalContentFromJSON(contentJSON); err != nil {
			return nil, contextutils.WrapError(err, "failed to unmarshal reported question content")
		}
		questions = append(questions, &ReportedQuestionWithUser{
			Question: &question,
			// NullString.String is "" when no user row matched the LEFT JOIN.
			ReportedByUsername: reportedByUsername.String,
			TotalResponses:     totalResponses,
		})
	}
	// rows.Next returns false both on exhaustion and on failure; only rows.Err
	// tells them apart.
	if err = rows.Err(); err != nil {
		return nil, contextutils.WrapError(err, "failed to iterate reported questions")
	}
	return questions, nil
}
// MarkQuestionAsFixed marks a reported question as fixed and puts it back in
// rotation by setting its status to models.QuestionStatusActive.
//
// When no row has the given ID, the returned error wraps
// contextutils.ErrRecordNotFound.
func (s *QuestionService) MarkQuestionAsFixed(ctx context.Context, questionID int) (err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "mark_question_as_fixed", observability.AttributeQuestionID(questionID))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const reactivate = `UPDATE questions SET status = $1 WHERE id = $2`
	res, execErr := s.db.ExecContext(ctx, reactivate, models.QuestionStatusActive, questionID)
	if execErr != nil {
		return contextutils.WrapError(execErr, "failed to mark question as fixed")
	}

	// Zero affected rows means the ID did not match any question.
	affected, countErr := res.RowsAffected()
	switch {
	case countErr != nil:
		return contextutils.WrapError(countErr, "failed to get rows affected")
	case affected == 0:
		return contextutils.WrapErrorf(contextutils.ErrRecordNotFound, "question with ID %d not found", questionID)
	default:
		return nil
	}
}
// UpdateQuestion updates a question's content, correct answer, and explanation.
//
// content is serialized with models.Question.MarshalContentToJSON (through a
// throwaway Question value), so it is stored in the same JSON form that
// UnmarshalContentFromJSON reads back. correctAnswerIndex is written to the
// correct_answer column unchanged.
//
// When no question has the given ID, the returned error wraps
// contextutils.ErrRecordNotFound.
func (s *QuestionService) UpdateQuestion(ctx context.Context, questionID int, content map[string]interface{}, correctAnswerIndex int, explanation string) (err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "update_question", observability.AttributeQuestionID(questionID))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()
	// Reuse the model's marshaller rather than encoding the map directly, so the
	// stored shape stays in sync with the model's own encoding rules.
	contentJSON, err := (&models.Question{Content: content}).MarshalContentToJSON()
	if err != nil {
		return contextutils.WrapError(err, "failed to marshal content JSON")
	}
	query := `UPDATE questions SET content = $1, correct_answer = $2, explanation = $3 WHERE id = $4`
	var result sql.Result
	result, err = s.db.ExecContext(ctx, query, contentJSON, correctAnswerIndex, explanation, questionID)
	if err != nil {
		return contextutils.WrapError(err, "failed to update question")
	}
	// Check if the question was actually updated.
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return contextutils.WrapError(err, "failed to get rows affected")
	}
	if rowsAffected == 0 {
		return contextutils.WrapErrorf(contextutils.ErrRecordNotFound, "question with ID %d not found", questionID)
	}
	return nil
}
// DeleteQuestion permanently deletes a question from the database, together
// with its user_responses rows.
//
// Both DELETE statements run in a single transaction. A failure part-way
// through therefore cannot leave the question in place with its responses
// already removed. When no question has the given ID, the returned error wraps
// contextutils.ErrRecordNotFound and the transaction is rolled back.
// NOTE(review): other tables that reference questions (e.g. user_questions,
// question_reports) are not touched here — presumably handled by FK cascades;
// confirm against the schema.
func (s *QuestionService) DeleteQuestion(ctx context.Context, questionID int) (err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "delete_question", observability.AttributeQuestionID(questionID))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapError(err, "failed to begin transaction")
	}
	// err is the named result, so this sees exactly the error being returned.
	defer func() {
		if err != nil {
			if rollbackErr := tx.Rollback(); rollbackErr != nil {
				s.logger.Warn(ctx, "Failed to rollback transaction", map[string]interface{}{"error": rollbackErr.Error()})
			}
		}
	}()
	// Delete dependent responses first (presumably required by a foreign key on
	// user_responses.question_id).
	deleteResponsesQuery := `DELETE FROM user_responses WHERE question_id = $1`
	if _, err = tx.ExecContext(ctx, deleteResponsesQuery, questionID); err != nil {
		return contextutils.WrapError(err, "failed to delete associated user responses")
	}
	// Then delete the question itself.
	deleteQuestionQuery := `DELETE FROM questions WHERE id = $1`
	var result sql.Result
	result, err = tx.ExecContext(ctx, deleteQuestionQuery, questionID)
	if err != nil {
		return contextutils.WrapError(err, "failed to delete question")
	}
	// Check if the question was actually deleted.
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return contextutils.WrapError(err, "failed to get rows affected")
	}
	if rowsAffected == 0 {
		return contextutils.WrapErrorf(contextutils.ErrRecordNotFound, "question with ID %d not found", questionID)
	}
	if err = tx.Commit(); err != nil {
		return contextutils.WrapError(err, "failed to commit transaction")
	}
	return nil
}
// ReportQuestion marks a question as reported/problematic by a specific user.
//
// In one transaction it checks that the question exists, sets its status to
// models.QuestionStatusReported, and upserts a question_reports row keyed by
// (question_id, reported_by_user_id). A repeat report from the same user
// replaces the reason and refreshes created_at. An empty reportReason is
// stored as "Question reported by user".
//
// A missing question yields an error wrapping contextutils.ErrRecordNotFound.
// On any error the transaction is rolled back; a failed rollback is only logged.
func (s *QuestionService) ReportQuestion(ctx context.Context, questionID, userID int, reportReason string) (err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "report_question", observability.AttributeQuestionID(questionID), observability.AttributeUserID(userID))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	tx, beginErr := s.db.BeginTx(ctx, nil)
	if beginErr != nil {
		return contextutils.WrapError(beginErr, "failed to begin transaction")
	}
	// Runs after every return below; err holds the value being returned.
	defer func() {
		if err == nil {
			return
		}
		if rbErr := tx.Rollback(); rbErr != nil {
			s.logger.Warn(ctx, "Failed to rollback transaction", map[string]interface{}{"error": rbErr.Error()})
		}
	}()

	var found bool
	const existsQuery = `SELECT EXISTS(SELECT 1 FROM questions WHERE id = $1)`
	if scanErr := tx.QueryRowContext(ctx, existsQuery, questionID).Scan(&found); scanErr != nil {
		return contextutils.WrapError(scanErr, "failed to check if question exists")
	}
	if !found {
		return contextutils.WrapErrorf(contextutils.ErrRecordNotFound, "question with id %d not found", questionID)
	}

	const flagQuery = `UPDATE questions SET status = $1 WHERE id = $2`
	res, updErr := tx.ExecContext(ctx, flagQuery, models.QuestionStatusReported, questionID)
	if updErr != nil {
		return contextutils.WrapError(updErr, "failed to update question status")
	}
	affected, countErr := res.RowsAffected()
	if countErr != nil {
		return contextutils.WrapError(countErr, "failed to get rows affected")
	}
	if affected == 0 {
		return contextutils.WrapErrorf(contextutils.ErrRecordNotFound, "question with ID %d not found", questionID)
	}

	storedReason := "Question reported by user"
	if reportReason != "" {
		storedReason = reportReason
	}

	// Upsert: a second report by the same user overwrites the reason and
	// timestamp instead of being ignored.
	const upsertReport = `INSERT INTO question_reports (question_id, reported_by_user_id, report_reason) VALUES ($1, $2, $3) ON CONFLICT (question_id, reported_by_user_id) DO UPDATE SET report_reason = EXCLUDED.report_reason, created_at = now()`
	if _, insErr := tx.ExecContext(ctx, upsertReport, questionID, userID, storedReason); insErr != nil {
		return contextutils.WrapError(insErr, "failed to create question report")
	}

	if commitErr := tx.Commit(); commitErr != nil {
		return contextutils.WrapError(commitErr, "failed to commit transaction")
	}
	return nil
}
// GetNextQuestion gets the next question for a user based on usage count and
// availability.
//
// It is a traced wrapper around getNextQuestionWithPriority, which ranks the
// user's candidate questions and attaches their response statistics. A nil
// question with a nil error means no question was available.
func (s *QuestionService) GetNextQuestion(ctx context.Context, userID int, language, level string, qType models.QuestionType) (result0 *QuestionWithStats, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_next_question",
		observability.AttributeUserID(userID),
		observability.AttributeLanguage(language),
		observability.AttributeLevel(level),
		observability.AttributeQuestionType(qType),
	)
	defer func() {
		// End the span last, after any error has been recorded on it.
		defer span.End()
		if err == nil {
			return
		}
		span.RecordError(err, trace.WithStackTrace(true))
		span.SetStatus(codes.Error, err.Error())
	}()

	result0, err = s.getNextQuestionWithPriority(ctx, userID, language, level, qType)
	return result0, err
}
// getNextQuestionWithPriority implements priority-based question selection
// with stats.
//
// It loads the user's learning preferences, falling back to
// GetDefaultLearningPreferences (and logging a warning) if that fails. It then
// fetches prioritized candidates and picks one with
// selectQuestionWithFreshnessRatio. When the user has no candidates, it falls
// back to GetRandomGlobalQuestionForUser; (nil, nil) means nothing was found
// anywhere.
func (s *QuestionService) getNextQuestionWithPriority(ctx context.Context, userID int, language, level string, qType models.QuestionType) (result0 *QuestionWithStats, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_next_question_with_priority", observability.AttributeUserID(userID), observability.AttributeLanguage(language), observability.AttributeLevel(level), observability.AttributeQuestionType(qType))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	var prefs *models.UserLearningPreferences
	var prefsErr error
	if prefs, prefsErr = s.learningService.GetUserLearningPreferences(ctx, userID); prefsErr != nil {
		s.logger.Warn(ctx, "Failed to get user preferences", map[string]interface{}{"user_id": userID, "error": prefsErr.Error()})
		prefs = s.learningService.GetDefaultLearningPreferences()
	}

	candidates, listErr := s.getAvailableQuestionsWithPriority(ctx, userID, language, level, qType, prefs)
	if listErr != nil {
		return nil, contextutils.WrapError(listErr, "failed to get available questions")
	}

	if len(candidates) > 0 {
		// Balance fresh vs. previously seen questions per the user's ratio.
		picked, pickErr := s.selectQuestionWithFreshnessRatio(candidates, prefs.FreshQuestionRatio)
		if pickErr != nil {
			return nil, contextutils.WrapError(pickErr, "failed to select question with freshness ratio")
		}
		return picked, nil
	}

	// No personalized candidates: try a random global question instead. A nil
	// result here is passed through as "no questions available at all".
	fallback, fallbackErr := s.GetRandomGlobalQuestionForUser(ctx, userID, language, level, qType)
	if fallbackErr != nil {
		return nil, contextutils.WrapError(fallbackErr, "no personalized questions, and failed to get global fallback question")
	}
	return fallback, nil
}
// GetAdaptiveQuestionsForDaily selects multiple adaptive questions for daily
// assignments: up to limit distinct questions for userID.
//
// Selection runs in three passes:
//  1. limit is split across the vocabulary, fill-in-blank, question-answer and
//     reading-comprehension types; the first limit%4 types get one extra slot.
//     Each type's candidates come from getAvailableQuestionsForDailyWithPriority
//     and are drawn one at a time with selectQuestionWithFreshnessRatio. If that
//     call fails, a uniformly random candidate is drawn instead.
//     A type with no unselected candidates contributes at most one question
//     from GetRandomGlobalQuestionForUser. A type that ends up short is topped
//     up from its candidate list in list order.
//  2. If fewer than limit questions were chosen, the remaining slots are drawn
//     from the union of all types' unselected candidates until either the
//     quota is met or the pool is empty.
//  3. The result is truncated to limit and defensively de-duplicated by ID.
//
// Failures while loading preferences or candidates are logged and skipped, so
// the returned error is always nil. A non-positive limit returns an empty slice.
func (s *QuestionService) GetAdaptiveQuestionsForDaily(ctx context.Context, userID int, language, level string, limit int) (result0 []*QuestionWithStats, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_adaptive_questions_for_daily",
		observability.AttributeUserID(userID),
		observability.AttributeLanguage(language),
		observability.AttributeLevel(level),
		observability.AttributeLimit(limit),
	)
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// Guard: with a negative limit the final truncation below would slice with
	// a negative bound and panic.
	if limit <= 0 {
		return []*QuestionWithStats{}, nil
	}

	var prefs *models.UserLearningPreferences
	var prefsErr error
	prefs, prefsErr = s.learningService.GetUserLearningPreferences(ctx, userID)
	if prefsErr != nil {
		s.logger.Warn(ctx, "Failed to get user learning preferences, using defaults", map[string]interface{}{
			"user_id": userID, "error": prefsErr.Error(),
		})
		// Use the full default preference set (as getNextQuestionWithPriority
		// does) rather than a struct with only FreshQuestionRatio populated,
		// since prefs is also passed to the candidate query.
		prefs = s.learningService.GetDefaultLearningPreferences()
	}

	questionTypes := []models.QuestionType{models.Vocabulary, models.FillInBlank, models.QuestionAnswer, models.ReadingComprehension}

	// Candidate lists are cached by type index, so each type is queried at most
	// once when it succeeds. A failed lookup is retried on the next request.
	cachedCandidates := make([][]*QuestionWithStats, len(questionTypes))
	isCached := make([]bool, len(questionTypes))
	candidatesFor := func(idx int) ([]*QuestionWithStats, error) {
		if isCached[idx] {
			return cachedCandidates[idx], nil
		}
		qs, qErr := s.getAvailableQuestionsForDailyWithPriority(ctx, userID, language, level, questionTypes[idx], prefs)
		if qErr != nil {
			return nil, qErr
		}
		cachedCandidates[idx], isCached[idx] = qs, true
		return qs, nil
	}

	// without returns a new slice holding pool minus the question with the
	// given ID; it never mutates pool, so cached lists stay intact.
	without := func(pool []*QuestionWithStats, id int) []*QuestionWithStats {
		kept := make([]*QuestionWithStats, 0, len(pool))
		for _, q := range pool {
			if q.ID != id {
				kept = append(kept, q)
			}
		}
		return kept
	}

	var selectedQuestions []*QuestionWithStats
	selectedQuestionIDs := make(map[int]bool) // prevents the same question being chosen twice
	take := func(q *QuestionWithStats) {
		selectedQuestions = append(selectedQuestions, q)
		selectedQuestionIDs[q.ID] = true
	}

	// Pass 1: per-type quotas for variety.
	questionsPerType := limit / len(questionTypes)
	remainingQuestions := limit % len(questionTypes)
	for i, qType := range questionTypes {
		currentLimit := questionsPerType
		if i < remainingQuestions {
			currentLimit++ // distribute the remainder over the first types
		}
		if currentLimit == 0 {
			continue
		}
		questions, qErr := candidatesFor(i)
		if qErr != nil {
			s.logger.Warn(ctx, "Failed to get questions for type", map[string]interface{}{
				"user_id": userID, "type": qType, "error": qErr.Error(),
			})
			continue
		}
		var availableQuestions []*QuestionWithStats
		for _, q := range questions {
			if !selectedQuestionIDs[q.ID] {
				availableQuestions = append(availableQuestions, q)
			}
		}
		if len(availableQuestions) == 0 {
			// Nothing personalized for this type: try one global fallback question.
			globalQ, globalErr := s.GetRandomGlobalQuestionForUser(ctx, userID, language, level, qType)
			if globalErr != nil {
				s.logger.Warn(ctx, "Failed to get global fallback question", map[string]interface{}{
					"user_id": userID, "type": qType, "error": globalErr.Error(),
				})
				continue
			}
			if globalQ != nil && !selectedQuestionIDs[globalQ.ID] {
				take(globalQ)
				s.logger.Info(ctx, "Added global fallback question", map[string]interface{}{
					"user_id": userID, "type": qType, "question_id": globalQ.ID,
				})
			}
			continue
		}
		s.logger.Info(ctx, "Starting selection for question type", map[string]interface{}{
			"user_id": userID, "type": qType, "current_limit": currentLimit, "available_questions": len(availableQuestions),
		})
		questionsSelected := 0
		remainingQuestionsForType := availableQuestions
		for j := 0; j < currentLimit && len(remainingQuestionsForType) > 0; j++ {
			selectedQuestion, selErr := s.selectQuestionWithFreshnessRatio(remainingQuestionsForType, prefs.FreshQuestionRatio)
			if selErr != nil {
				s.logger.Warn(ctx, "Failed to select question with freshness ratio", map[string]interface{}{
					"user_id": userID, "type": qType, "error": selErr.Error(),
				})
				// Fall back to a uniform random pick; the loop condition
				// guarantees the pool is non-empty here.
				selectedQuestion = remainingQuestionsForType[rand.Intn(len(remainingQuestionsForType))]
			}
			if selectedQuestion != nil && !selectedQuestionIDs[selectedQuestion.ID] {
				take(selectedQuestion)
				questionsSelected++
				remainingQuestionsForType = without(remainingQuestionsForType, selectedQuestion.ID)
				s.logger.Info(ctx, "Successfully selected question", map[string]interface{}{
					"user_id": userID, "type": qType, "iteration": j, "question_id": selectedQuestion.ID,
					"total_selected": len(selectedQuestions),
				})
				continue
			}
			s.logger.Warn(ctx, "Failed to select question for type", map[string]interface{}{
				"user_id": userID, "type": qType, "iteration": j, "current_limit": currentLimit,
				"selected_question_nil": selectedQuestion == nil,
				"already_selected":      selectedQuestion != nil && selectedQuestionIDs[selectedQuestion.ID],
			})
			// Drop an already-selected pick from the pool so it is not offered again.
			if selectedQuestion != nil {
				remainingQuestionsForType = without(remainingQuestionsForType, selectedQuestion.ID)
			}
		}
		// Top up a short type from its candidate list (already fetched above).
		if questionsSelected < currentLimit {
			s.logger.Info(ctx, "Using simple selection to fill remaining slots", map[string]interface{}{
				"user_id": userID, "type": qType, "questions_selected": questionsSelected, "current_limit": currentLimit,
			})
			for _, q := range questions {
				if questionsSelected >= currentLimit {
					break
				}
				if !selectedQuestionIDs[q.ID] {
					take(q)
					questionsSelected++
				}
			}
		}
		s.logger.Info(ctx, "Completed selection for question type", map[string]interface{}{
			"user_id": userID, "type": qType, "questions_selected": questionsSelected, "target": currentLimit,
		})
	}

	// Pass 2: fill any remaining slots from every type's unselected candidates.
	if len(selectedQuestions) < limit {
		remainingNeeded := limit - len(selectedQuestions)
		s.logger.Info(ctx, "Not enough questions from type-based selection, using fallback", map[string]interface{}{
			"user_id": userID, "selected_count": len(selectedQuestions), "limit": limit, "remaining_needed": remainingNeeded,
		})
		var allQuestions []*QuestionWithStats
		questionIDMap := make(map[int]bool) // de-duplicates across types
		for idx := range questionTypes {
			questions, qErr := candidatesFor(idx)
			if qErr != nil {
				continue
			}
			for _, q := range questions {
				if !questionIDMap[q.ID] && !selectedQuestionIDs[q.ID] {
					allQuestions = append(allQuestions, q)
					questionIDMap[q.ID] = true
				}
			}
		}
		s.logger.Info(ctx, "Fallback questions available", map[string]interface{}{
			"user_id": userID, "all_questions_count": len(allQuestions),
		})
		// Every iteration removes one question from the pool, so this ends once
		// the quota is met or the pool is exhausted. The bound must not compare
		// an increasing index against the shrinking pool length.
		for len(selectedQuestions) < limit && len(allQuestions) > 0 {
			selectedQuestion, selErr := s.selectQuestionWithFreshnessRatio(allQuestions, prefs.FreshQuestionRatio)
			if selErr != nil {
				s.logger.Warn(ctx, "Failed to select question with freshness ratio in fallback", map[string]interface{}{
					"user_id": userID, "error": selErr.Error(),
				})
				selectedQuestion = allQuestions[rand.Intn(len(allQuestions))]
			}
			if selectedQuestion == nil {
				// No pick means no pool progress; stop rather than loop forever.
				break
			}
			if !selectedQuestionIDs[selectedQuestion.ID] {
				take(selectedQuestion)
			}
			allQuestions = without(allQuestions, selectedQuestion.ID)
		}
	}

	// Pass 3: enforce the limit and de-duplicate defensively.
	if len(selectedQuestions) > limit {
		selectedQuestions = selectedQuestions[:limit]
	}
	finalSelectedQuestions := make([]*QuestionWithStats, 0, len(selectedQuestions))
	finalSelectedIDs := make(map[int]bool, len(selectedQuestions))
	for _, q := range selectedQuestions {
		if finalSelectedIDs[q.ID] {
			s.logger.Warn(ctx, "Duplicate question detected in final selection", map[string]interface{}{
				"user_id": userID, "question_id": q.ID,
			})
			continue
		}
		finalSelectedQuestions = append(finalSelectedQuestions, q)
		finalSelectedIDs[q.ID] = true
	}
	s.logger.Info(ctx, "Selected adaptive questions for daily assignment", map[string]interface{}{
		"user_id":            userID,
		"language":           language,
		"level":              level,
		"requested_limit":    limit,
		"selected_count":     len(finalSelectedQuestions),
		"duplicates_removed": len(selectedQuestions) - len(finalSelectedQuestions),
	})
	return finalSelectedQuestions, nil
}
// GetQuestionStats returns basic statistics about questions in the system.
//
// The result map holds:
//   - "total_questions": int, COUNT(*) over questions
//   - "questions_by_type": map[string]int keyed by questions.type
//   - "questions_by_level": map[string]int keyed by questions.level
//
// Each grouped query's result set is closed before the next query runs.
// Iteration errors are returned rather than producing a partial map.
func (s *QuestionService) GetQuestionStats(ctx context.Context) (result0 map[string]interface{}, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_question_stats")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// countBy runs a two-column "key, COUNT(*)" query and collects it into a
	// map. what names the grouping (e.g. "type") in error messages.
	countBy := func(query, what string) (map[string]int, error) {
		rows, qErr := s.db.QueryContext(ctx, query)
		if qErr != nil {
			return nil, contextutils.WrapError(qErr, "failed to query questions by "+what)
		}
		// Scoped to this call, so each result set is closed exactly once.
		defer func() {
			if closeErr := rows.Close(); closeErr != nil {
				s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
			}
		}()
		counts := make(map[string]int)
		for rows.Next() {
			var key string
			var count int
			if scanErr := rows.Scan(&key, &count); scanErr != nil {
				return nil, contextutils.WrapError(scanErr, "failed to scan question "+what+" count")
			}
			counts[key] = count
		}
		if iterErr := rows.Err(); iterErr != nil {
			return nil, contextutils.WrapError(iterErr, "failed to iterate questions by "+what)
		}
		return counts, nil
	}

	stats := make(map[string]interface{})

	var totalQuestions int
	err = s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM questions").Scan(&totalQuestions)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get total questions count")
	}
	stats["total_questions"] = totalQuestions

	typeQuery := `
SELECT type, COUNT(*) as count
FROM questions
GROUP BY type
`
	questionsByType, err := countBy(typeQuery, "type")
	if err != nil {
		return nil, err
	}
	stats["questions_by_type"] = questionsByType

	levelQuery := `
SELECT level, COUNT(*) as count
FROM questions
GROUP BY level
`
	questionsByLevel, err := countBy(levelQuery, "level")
	if err != nil {
		return nil, err
	}
	stats["questions_by_level"] = questionsByLevel

	return stats, nil
}
// GetDetailedQuestionStats returns detailed statistics about questions.
//
// The result map holds:
//   - "total_questions": int, COUNT(*) over questions
//   - "questions_by_detail": map[language]map[level]map[type]int
//   - "questions_by_language", "questions_by_type", "questions_by_level":
//     map[string]int keyed by the respective column
//
// Each query's result set is closed before the next query runs. Iteration
// errors are returned instead of silently yielding partial maps.
func (s *QuestionService) GetDetailedQuestionStats(ctx context.Context) (result0 map[string]interface{}, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_detailed_question_stats")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	// eachRow runs query, hands every row to scan, and always closes the
	// result set before returning. what names the grouping in error messages.
	eachRow := func(query, what string, scan func(*sql.Rows) error) error {
		rows, qErr := s.db.QueryContext(ctx, query)
		if qErr != nil {
			return contextutils.WrapError(qErr, "failed to query questions by "+what)
		}
		defer func() {
			if closeErr := rows.Close(); closeErr != nil {
				s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
			}
		}()
		for rows.Next() {
			if scanErr := scan(rows); scanErr != nil {
				return contextutils.WrapError(scanErr, "failed to scan questions by "+what)
			}
		}
		if iterErr := rows.Err(); iterErr != nil {
			return contextutils.WrapError(iterErr, "failed to iterate questions by "+what)
		}
		return nil
	}

	// countBy collects a two-column "key, COUNT(*)" result into a map.
	countBy := func(query, what string) (map[string]int, error) {
		counts := make(map[string]int)
		scanErr := eachRow(query, what, func(rows *sql.Rows) error {
			var key string
			var count int
			if sErr := rows.Scan(&key, &count); sErr != nil {
				return sErr
			}
			counts[key] = count
			return nil
		})
		if scanErr != nil {
			return nil, scanErr
		}
		return counts, nil
	}

	stats := make(map[string]interface{})

	var totalQuestions int
	err = s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM questions").Scan(&totalQuestions)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get total questions count")
	}
	stats["total_questions"] = totalQuestions

	// Questions by language, level, and type combination.
	detailQuery := `
SELECT language, level, type, COUNT(*) as count
FROM questions
GROUP BY language, level, type
ORDER BY language, level, type
`
	questionsByDetail := make(map[string]map[string]map[string]int)
	err = eachRow(detailQuery, "language, level and type", func(rows *sql.Rows) error {
		var language, level, qType string
		var count int
		if sErr := rows.Scan(&language, &level, &qType, &count); sErr != nil {
			return sErr
		}
		byLevel := questionsByDetail[language]
		if byLevel == nil {
			byLevel = make(map[string]map[string]int)
			questionsByDetail[language] = byLevel
		}
		byType := byLevel[level]
		if byType == nil {
			byType = make(map[string]int)
			byLevel[level] = byType
		}
		byType[qType] = count
		return nil
	})
	if err != nil {
		return nil, err
	}
	stats["questions_by_detail"] = questionsByDetail

	languageQuery := `
SELECT language, COUNT(*) as count
FROM questions
GROUP BY language
`
	questionsByLanguage, err := countBy(languageQuery, "language")
	if err != nil {
		return nil, err
	}
	stats["questions_by_language"] = questionsByLanguage

	typeQuery := `
SELECT type, COUNT(*) as count
FROM questions
GROUP BY type
`
	questionsByType, err := countBy(typeQuery, "type")
	if err != nil {
		return nil, err
	}
	stats["questions_by_type"] = questionsByType

	levelQuery := `
SELECT level, COUNT(*) as count
FROM questions
GROUP BY level
`
	questionsByLevel, err := countBy(levelQuery, "level")
	if err != nil {
		return nil, err
	}
	stats["questions_by_level"] = questionsByLevel

	return stats, nil
}
// GetRecentQuestionContentsForUser retrieves recent question contents for a user.
//
// It returns the raw content column of at most limit distinct questions that
// userID has both been assigned (user_questions) and answered
// (user_responses). Rows are ordered by question creation time, newest
// first. The result is never nil: an empty slice is returned when nothing
// matches or on error.
func (s *QuestionService) GetRecentQuestionContentsForUser(ctx context.Context, userID, limit int) (result0 []string, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_recent_question_contents_for_user", observability.AttributeUserID(userID), observability.AttributeLimit(limit))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()
	// GROUP BY replaces DISTINCT so the rows can be ordered by an aggregate.
	// Ordering by the content text itself did not reflect recency at all, so
	// LIMIT kept an arbitrary subset.
	// NOTE(review): ordered by question creation time; if "recent" should mean
	// most recently answered, order by a user_responses timestamp instead.
	query := `
SELECT q.content
FROM user_responses ur
JOIN questions q ON ur.question_id = q.id
JOIN user_questions uq ON q.id = uq.question_id
WHERE ur.user_id = $1 AND uq.user_id = $2
GROUP BY q.content
ORDER BY MAX(q.created_at) DESC
LIMIT $3
`
	var rows *sql.Rows
	rows, err = s.db.QueryContext(ctx, query, userID, userID, limit)
	if err != nil {
		return []string{}, err
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()
	contents := []string{}
	for rows.Next() {
		var content string
		if err = rows.Scan(&content); err != nil {
			return []string{}, err
		}
		contents = append(contents, content)
	}
	// Distinguish end-of-data from an error that aborted iteration.
	if err = rows.Err(); err != nil {
		return []string{}, err
	}
	return contents, nil
}
// GetUserQuestions retrieves actual questions for a user (not just content).
//
// It returns up to limit questions linked to userID through user_questions,
// newest first, including the generation metadata columns (topic_category,
// grammar_focus, vocabulary_domain, scenario, style_modifier,
// difficulty_modifier, time_context). Rows are decoded by scanQuestionFromRows.
func (s *QuestionService) GetUserQuestions(ctx context.Context, userID, limit int) (result0 []*models.Question, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_user_questions", observability.AttributeUserID(userID), observability.AttributeLimit(limit))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()
	query := `
SELECT q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status, q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
WHERE uq.user_id = $1
ORDER BY q.created_at DESC
LIMIT $2
`
	var rows *sql.Rows
	rows, err = s.db.QueryContext(ctx, query, userID, limit)
	if err != nil {
		return nil, err
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()
	var questions []*models.Question
	for rows.Next() {
		question, scanErr := s.scanQuestionFromRows(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		questions = append(questions, question)
	}
	// rows.Next also returns false on a mid-stream error; surface it instead of
	// returning a silently truncated list.
	if err = rows.Err(); err != nil {
		return nil, contextutils.WrapError(err, "failed to iterate user questions")
	}
	return questions, nil
}
// GetUserQuestionsWithStats retrieves questions for a user with response
// statistics.
//
// For up to limit questions linked to userID (newest first), it aggregates
// correct, incorrect and total user_responses counts. Responses are joined on
// question_id alone, so they cover all users. It also reports how many users
// the question is assigned to. Rows are decoded by
// scanQuestionWithStatsFromRows.
func (s *QuestionService) GetUserQuestionsWithStats(ctx context.Context, userID, limit int) (result0 []*QuestionWithStats, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_user_questions_with_stats", observability.AttributeUserID(userID), observability.AttributeLimit(limit))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const statsQuery = `
SELECT
q.id, q.type, q.language, q.level, q.difficulty_score,
q.content, q.correct_answer, q.explanation, q.created_at, q.status,
COALESCE(SUM(CASE WHEN ur.is_correct = true THEN 1 ELSE 0 END), 0) as correct_count,
COALESCE(SUM(CASE WHEN ur.is_correct = false THEN 1 ELSE 0 END), 0) as incorrect_count,
COALESCE(COUNT(ur.id), 0) as total_responses,
COALESCE(uq_stats.user_count, 0) as user_count
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
LEFT JOIN user_responses ur ON q.id = ur.question_id
LEFT JOIN (
SELECT
question_id,
COUNT(*) as user_count
FROM user_questions
GROUP BY question_id
) uq_stats ON q.id = uq_stats.question_id
WHERE uq.user_id = $1
GROUP BY q.id, q.type, q.language, q.level, q.difficulty_score,
q.content, q.correct_answer, q.explanation, q.created_at, q.status,
uq_stats.user_count
ORDER BY q.created_at DESC
LIMIT $2
`
	resultSet, queryErr := s.db.QueryContext(ctx, statsQuery, userID, limit)
	if queryErr != nil {
		return nil, queryErr
	}
	defer func() {
		if closeErr := resultSet.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var collected []*QuestionWithStats
	for resultSet.Next() {
		item, scanErr := s.scanQuestionWithStatsFromRows(resultSet)
		if scanErr != nil {
			return nil, scanErr
		}
		collected = append(collected, item)
	}
	if iterErr := resultSet.Err(); iterErr != nil {
		return nil, iterErr
	}
	return collected, nil
}
// QuestionWithStats represents a question with response statistics.
//
// The counters are filled from aggregate SQL columns of the same names in this
// service's queries. NOTE(review): presumably a field stays at its zero value
// when the query/scan helper used does not select it — confirm per scan helper.
type QuestionWithStats struct {
// Question is the underlying question; its fields (ID, Type, Content, ...) are promoted.
*models.Question
// CorrectCount counts joined user_responses rows with is_correct = true
// (the visible queries join responses on question_id only, i.e. across all users).
CorrectCount int `json:"correct_count"`
// IncorrectCount counts joined user_responses rows with is_correct = false.
IncorrectCount int `json:"incorrect_count"`
// TotalResponses counts joined user_responses rows (COUNT(ur.id)).
TotalResponses int `json:"total_responses"`
// TimesAnswered tracks how many times THIS user answered the question (per-user)
TimesAnswered int `json:"times_answered"`
// UserCount is the number of user_questions rows linking users to the question.
UserCount int `json:"user_count"`
// NOTE(review): Reporters and ReportReasons are not populated by any query
// visible here; presumably aggregated from question_reports — confirm.
Reporters string `json:"reporters,omitempty"`
ReportReasons string `json:"report_reasons,omitempty"`
// ConfidenceLevel is optional (nil when absent); presumably scanned from
// uqm.confidence_level in the priority queries — confirm.
ConfidenceLevel *int `json:"confidence_level,omitempty"`
}
// GetQuestionsPaginated retrieves questions with pagination and response
// statistics.
//
// Only questions linked to userID via user_questions are considered. Optional
// filters:
//   - search: case-insensitive substring match on content (as text) or
//     explanation
//   - typeFilter: exact type match
//   - statusFilter: exact status match
//
// All filter values are bound as query parameters. page is 1-based; values
// below 1 are treated as 1. A negative pageSize is treated as 0, which yields
// an empty page. PostgreSQL rejects negative LIMIT/OFFSET values outright.
// The second return value is the total number of matching questions,
// ignoring pagination. Rows are ordered by question ID descending.
func (s *QuestionService) GetQuestionsPaginated(ctx context.Context, userID, page, pageSize int, search, typeFilter, statusFilter string) (result0 []*QuestionWithStats, result1 int, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_questions_paginated", observability.AttributeUserID(userID), observability.AttributePage(page), observability.AttributePageSize(pageSize), observability.AttributeSearch(search), observability.AttributeTypeFilter(typeFilter), observability.AttributeStatusFilter(statusFilter))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()
	// Normalize paging input so the computed OFFSET/LIMIT are never negative.
	if page < 1 {
		page = 1
	}
	if pageSize < 0 {
		pageSize = 0
	}
	// Build WHERE clause with filters using parameterized queries.
	whereConditions := []string{"uq.user_id = $1"}
	args := []interface{}{userID}
	argCount := 1
	if search != "" {
		argCount++
		whereConditions = append(whereConditions, fmt.Sprintf("(q.content::text ILIKE $%d OR q.explanation ILIKE $%d)", argCount, argCount))
		args = append(args, "%"+search+"%")
	}
	if typeFilter != "" {
		argCount++
		whereConditions = append(whereConditions, fmt.Sprintf("q.type = $%d", argCount))
		args = append(args, typeFilter)
	}
	if statusFilter != "" {
		argCount++
		whereConditions = append(whereConditions, fmt.Sprintf("q.status = $%d", argCount))
		args = append(args, statusFilter)
	}
	whereClause := "WHERE " + strings.Join(whereConditions, " AND ")
	// Total count with the same filters (pagination not applied).
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM questions q JOIN user_questions uq ON q.id = uq.question_id %s", whereClause)
	var totalCount int
	err = s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalCount)
	if err != nil {
		return nil, 0, err
	}
	offset := (page - 1) * pageSize
	// LIMIT/OFFSET placeholders follow the filter parameters.
	query := fmt.Sprintf(`
SELECT
q.id, q.type, q.language, q.level, q.difficulty_score,
q.content, q.correct_answer, q.explanation, q.created_at, q.status,
q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context,
COALESCE(SUM(CASE WHEN ur.is_correct = true THEN 1 ELSE 0 END), 0) as correct_count,
COALESCE(SUM(CASE WHEN ur.is_correct = false THEN 1 ELSE 0 END), 0) as incorrect_count,
COALESCE(COUNT(ur.id), 0) as total_responses,
COALESCE(uq_stats.user_count, 0) as user_count
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
LEFT JOIN user_responses ur ON q.id = ur.question_id
LEFT JOIN (
SELECT
question_id,
COUNT(*) as user_count
FROM user_questions
GROUP BY question_id
) uq_stats ON q.id = uq_stats.question_id
%s
GROUP BY q.id, q.type, q.language, q.level, q.difficulty_score,
q.content, q.correct_answer, q.explanation, q.created_at, q.status,
q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context,
uq_stats.user_count
ORDER BY q.id DESC
LIMIT $%d OFFSET $%d
`, whereClause, argCount+1, argCount+2)
	args = append(args, pageSize, offset)
	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, 0, err
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()
	var questions []*QuestionWithStats
	for rows.Next() {
		questionWithStats, scanErr := s.scanQuestionWithStatsAndAllFieldsFromRows(rows)
		if scanErr != nil {
			return nil, 0, scanErr
		}
		questions = append(questions, questionWithStats)
	}
	if err = rows.Err(); err != nil {
		return nil, 0, err
	}
	return questions, totalCount, nil
}
// PRIORITY-BASED QUESTION SELECTION METHODS
// getAvailableQuestionsWithPriority returns up to 50 active questions of the
// given language, level and type that are assigned to userID through
// user_questions. Results are ordered by the user's priority score (100.0 when
// no question_priority_scores row exists), with random tie-breaking.
//
// Each row also carries:
//   - the user's own answer count and last answer time;
//   - correct/incorrect/total response counts across all users;
//   - the user's confidence level from user_question_metadata.
//
// A question is excluded when the user:
//   - answered it within the last hour;
//   - answered it correctly on each of their three most recent attempts
//     within the last 90 days (fewer than three attempts never excludes); or
//   - marked it as known with confidence level 5 within the last 60 days.
//
// Rows that fail to scan are logged and skipped (best effort). An error
// reported by the driver while iterating is returned, so callers never act on
// a silently truncated candidate list. The preferences argument is unused.
func (s *QuestionService) getAvailableQuestionsWithPriority(ctx context.Context, userID int, language, level string, qType models.QuestionType, _ *models.UserLearningPreferences) (result0 []*QuestionWithStats, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_available_questions_with_priority", observability.AttributeUserID(userID), observability.AttributeLanguage(language), observability.AttributeLevel(level), observability.AttributeQuestionType(qType))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()
	// Build SQL query with priority scoring and stats.
	// The "three most recent responses all correct" exclusion only fires when
	// the correlated LIMIT 3 subquery yields exactly three correct rows.
	query := `
SELECT q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status,
q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context,
COALESCE(qps.priority_score, 100.0) as priority_score,
COALESCE(uq_stats.times_answered, 0) as times_answered,
uq_stats.last_answered_at,
COALESCE(stats.correct_count, 0) as correct_count,
COALESCE(stats.incorrect_count, 0) as incorrect_count,
COALESCE(stats.total_responses, 0) as total_responses,
uqm.confidence_level
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
LEFT JOIN question_priority_scores qps ON q.id = qps.question_id AND qps.user_id = $1
LEFT JOIN (
SELECT question_id,
COUNT(*) as times_answered,
MAX(created_at) as last_answered_at
FROM user_responses
WHERE user_id = $1
GROUP BY question_id
) uq_stats ON q.id = uq_stats.question_id
LEFT JOIN (
SELECT
question_id,
COUNT(CASE WHEN is_correct = true THEN 1 END) as correct_count,
COUNT(CASE WHEN is_correct = false THEN 1 END) as incorrect_count,
COUNT(*) as total_responses
FROM user_responses
GROUP BY question_id
) stats ON q.id = stats.question_id
LEFT JOIN user_question_metadata uqm ON q.id = uqm.question_id AND uqm.user_id = $1
WHERE uq.user_id = $1
AND q.language = $2
AND q.level = $3
AND q.type = $4
AND q.status = 'active'
AND q.id NOT IN (
SELECT ur.question_id
FROM user_responses ur
WHERE ur.user_id = $1
AND ur.created_at > NOW() - INTERVAL '1 hour'
)
-- Exclude questions where the user's last 3 responses were all correct within the last 90 days
AND NOT EXISTS (
SELECT 1 FROM (
SELECT ur2.is_correct
FROM user_responses ur2
WHERE ur2.user_id = $1
AND ur2.question_id = q.id
AND ur2.created_at >= NOW() - INTERVAL '90 days'
ORDER BY ur2.created_at DESC
LIMIT 3
) recent_three
WHERE (SELECT COUNT(*) FROM (
SELECT 1 FROM (
SELECT ur3.is_correct
FROM user_responses ur3
WHERE ur3.user_id = $1
AND ur3.question_id = q.id
AND ur3.created_at >= NOW() - INTERVAL '90 days'
ORDER BY ur3.created_at DESC
LIMIT 3
) t WHERE t.is_correct = TRUE
) c) = 3
)
-- Exclude questions the user explicitly marked as known with max confidence (5)
-- within the last 60 days (approx. 2 months)
AND NOT EXISTS (
SELECT 1 FROM user_question_metadata uqm2
WHERE uqm2.user_id = $1
AND uqm2.question_id = q.id
AND uqm2.marked_as_known = TRUE
AND uqm2.confidence_level = 5
AND uqm2.marked_as_known_at >= NOW() - INTERVAL '60 days'
)
ORDER BY priority_score DESC, RANDOM()
LIMIT 50
`
	rows, err := s.db.QueryContext(ctx, query, userID, language, level, qType)
	if err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to query questions: %w", err)
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()
	var questions []*QuestionWithStats
	for rows.Next() {
		questionWithStats, scanErr := s.scanQuestionWithPriorityAndStatsFromRows(rows)
		if scanErr != nil {
			// Best effort: one malformed row should not block question selection.
			s.logger.Error(ctx, "Error scanning question", scanErr, map[string]interface{}{})
			continue
		}
		questions = append(questions, questionWithStats)
	}
	// rows.Next returns false both at the end of the result set and on failure;
	// only rows.Err tells the two apart.
	if err = rows.Err(); err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to iterate questions: %w", err)
	}
	return questions, nil
}
// getAvailableQuestionsForDailyWithPriority is the daily-quiz variant of
// getAvailableQuestionsWithPriority. It returns up to 50 active questions of
// the given language, level and type that are assigned to userID, ordered by
// priority score (100.0 by default) with random tie-breaking.
//
// Daily eligibility differs from the regular selection. A question is
// excluded only when the user:
//   - answered it correctly within the last N days, where N comes from
//     s.getDailyRepeatAvoidDays(); or
//   - marked it as known with confidence level 5 within the last 60 days.
//
// Rows that fail to scan are logged and skipped. Iteration errors reported by
// the driver are returned. The preferences argument is unused.
func (s *QuestionService) getAvailableQuestionsForDailyWithPriority(ctx context.Context, userID int, language, level string, qType models.QuestionType, _ *models.UserLearningPreferences) (result0 []*QuestionWithStats, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_available_questions_for_daily_with_priority", observability.AttributeUserID(userID), observability.AttributeLanguage(language), observability.AttributeLevel(level), observability.AttributeQuestionType(qType))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()
	avoidDays := s.getDailyRepeatAvoidDays()
	// NOTE(review): $5 is concatenated with ' days' and cast to interval. This
	// relies on the driver sending avoidDays in a text-compatible form; confirm
	// if the driver is ever switched.
	query := `
SELECT q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status,
q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context,
COALESCE(qps.priority_score, 100.0) as priority_score,
COALESCE(uq_stats.times_answered, 0) as times_answered,
uq_stats.last_answered_at,
COALESCE(stats.correct_count, 0) as correct_count,
COALESCE(stats.incorrect_count, 0) as incorrect_count,
COALESCE(stats.total_responses, 0) as total_responses,
uqm.confidence_level
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
LEFT JOIN question_priority_scores qps ON q.id = qps.question_id AND qps.user_id = $1
LEFT JOIN (
SELECT question_id,
COUNT(*) as times_answered,
MAX(created_at) as last_answered_at
FROM user_responses
WHERE user_id = $1
GROUP BY question_id
) uq_stats ON q.id = uq_stats.question_id
LEFT JOIN (
SELECT
question_id,
COUNT(CASE WHEN is_correct = true THEN 1 END) as correct_count,
COUNT(CASE WHEN is_correct = false THEN 1 END) as incorrect_count,
COUNT(*) as total_responses
FROM user_responses
GROUP BY question_id
) stats ON q.id = stats.question_id
LEFT JOIN user_question_metadata uqm ON q.id = uqm.question_id AND uqm.user_id = $1
WHERE uq.user_id = $1
AND q.language = $2
AND q.level = $3
AND q.type = $4
AND q.status = 'active'
AND NOT EXISTS (
SELECT 1
FROM user_responses ur
WHERE ur.user_id = $1
AND ur.question_id = q.id
AND ur.is_correct = TRUE
AND ur.created_at >= NOW() - ($5 || ' days')::interval
)
-- Exclude questions the user marked as known with confidence 5 within last 60 days
AND NOT EXISTS (
SELECT 1 FROM user_question_metadata uqm2
WHERE uqm2.user_id = $1
AND uqm2.question_id = q.id
AND uqm2.marked_as_known = TRUE
AND uqm2.confidence_level = 5
AND uqm2.marked_as_known_at >= NOW() - INTERVAL '60 days'
)
ORDER BY priority_score DESC, RANDOM()
LIMIT 50
`
	rows, err := s.db.QueryContext(ctx, query, userID, language, level, qType, avoidDays)
	if err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to query questions (daily): %w", err)
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()
	var questions []*QuestionWithStats
	for rows.Next() {
		questionWithStats, scanErr := s.scanQuestionWithPriorityAndStatsFromRows(rows)
		if scanErr != nil {
			// Best effort: skip malformed rows instead of failing the daily selection.
			s.logger.Error(ctx, "Error scanning question (daily)", scanErr, map[string]interface{}{})
			continue
		}
		questions = append(questions, questionWithStats)
	}
	// Distinguish "no more rows" from "iteration failed part-way".
	if err = rows.Err(); err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to iterate questions (daily): %w", err)
	}
	return questions, nil
}
// selectQuestionWithWeightedRandomness draws one candidate at random, biased
// towards less-used questions. Each candidate is weighted 1/(usage+1). Usage
// is the candidate's TimesAnswered, or TotalResponses when TimesAnswered is
// negative.
//
// An empty input yields an ErrRecordNotFound-wrapped error. If rounding ever
// leaves the running sum short of the drawn target, the last candidate is
// returned.
func (s *QuestionService) selectQuestionWithWeightedRandomness(questions []*QuestionWithStats) (result0 *QuestionWithStats, err error) {
	if len(questions) == 0 {
		return nil, contextutils.WrapError(contextutils.ErrRecordNotFound, "no questions available")
	}

	// Compute every weight once so the sampling pass walks exactly the values
	// that were summed.
	weights := make([]float64, len(questions))
	sum := 0.0
	for i, candidate := range questions {
		usage := candidate.TotalResponses
		if candidate.TimesAnswered >= 0 {
			usage = candidate.TimesAnswered
		}
		weights[i] = 1.0 / (float64(usage) + 1.0)
		sum += weights[i]
	}

	// Degenerate total: fall back to a uniform pick.
	if sum <= 0 {
		return questions[rand.Intn(len(questions))], nil
	}

	threshold := rand.Float64() * sum
	running := 0.0
	for i, w := range weights {
		running += w
		if running >= threshold {
			return questions[i], nil
		}
	}

	// Floating-point rounding left running just below threshold.
	return questions[len(questions)-1], nil
}
// selectQuestionWithFreshnessRatio splits the candidates into two pools:
//   - fresh: TimesAnswered == 0, or TotalResponses == 0 when TimesAnswered is
//     negative;
//   - review: everything else.
//
// When both pools are non-empty, the fresh pool is chosen with probability
// freshnessRatio; otherwise whichever pool has entries is used. The final pick
// within the pool is delegated to selectQuestionWithWeightedRandomness. Its
// error, if any, is logged with pool sizes and returned unchanged.
func (s *QuestionService) selectQuestionWithFreshnessRatio(questions []*QuestionWithStats, freshnessRatio float64) (result0 *QuestionWithStats, err error) {
	if len(questions) == 0 {
		return nil, contextutils.WrapError(contextutils.ErrRecordNotFound, "no questions available")
	}

	var fresh, review []*QuestionWithStats
	for _, candidate := range questions {
		neverAnswered := candidate.TotalResponses == 0
		if candidate.TimesAnswered >= 0 {
			neverAnswered = candidate.TimesAnswered == 0
		}
		if neverAnswered {
			fresh = append(fresh, candidate)
			continue
		}
		review = append(review, candidate)
	}

	var pool []*QuestionWithStats
	switch {
	case len(fresh) > 0 && len(review) > 0:
		// Both pools populated: roll against the configured ratio.
		if rand.Float64() < freshnessRatio {
			pool = fresh
		} else {
			pool = review
		}
	case len(fresh) > 0:
		pool = fresh
	case len(review) > 0:
		pool = review
	default:
		pool = questions
	}

	if len(pool) == 0 {
		return nil, contextutils.WrapError(contextutils.ErrRecordNotFound, "no questions available after freshness filtering")
	}

	picked, selectErr := s.selectQuestionWithWeightedRandomness(pool)
	if selectErr != nil {
		s.logger.Warn(context.Background(), "selectQuestionWithWeightedRandomness failed", map[string]interface{}{
			"total_questions":        len(questions),
			"fresh_questions":        len(fresh),
			"review_questions":       len(review),
			"selected_category_size": len(pool),
			"freshness_ratio":        freshnessRatio,
			"error":                  selectErr.Error(),
		})
	}
	return picked, selectErr
}
// GetUserQuestionCount reports how many distinct active questions are linked
// to userID through user_questions. Query failures are wrapped with
// ErrDatabaseQuery.
func (s *QuestionService) GetUserQuestionCount(ctx context.Context, userID int) (result0 int, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_user_question_count", observability.AttributeUserID(userID))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const countSQL = `
SELECT COUNT(DISTINCT q.id)
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
WHERE uq.user_id = $1 AND q.status = 'active'
`
	var total int
	if err = s.db.QueryRowContext(ctx, countSQL, userID).Scan(&total); err != nil {
		return 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get user question count: %w", err)
	}
	return total, nil
}
// GetUserResponseCount counts the user_responses rows recorded for userID.
// Query failures are wrapped with ErrDatabaseQuery.
func (s *QuestionService) GetUserResponseCount(ctx context.Context, userID int) (result0 int, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_user_response_count", observability.AttributeUserID(userID))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	var responses int
	row := s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM user_responses WHERE user_id = $1`, userID)
	if err = row.Scan(&responses); err != nil {
		return 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get user response count: %w", err)
	}
	return responses, nil
}
// GetUsersForQuestion returns a preview of the users assigned to questionID,
// together with the total number of user_questions rows for that question.
// The preview holds at most five users, ordered by username.
//
// The returned slice is never nil on success; a question with no assignees
// yields an empty slice.
//
// NOTE(review): the preview query loads password_hash and ai_api_key into each
// models.User. Confirm that callers never serialize those fields.
func (s *QuestionService) GetUsersForQuestion(ctx context.Context, questionID int) (result0 []*models.User, result1 int, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_users_for_question", observability.AttributeQuestionID(questionID))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	var assigned int
	if err = s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM user_questions WHERE question_id = $1`, questionID).Scan(&assigned); err != nil {
		return nil, 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get user count for question: %w", err)
	}

	const previewSQL = `
SELECT u.id, u.username, u.email, u.timezone, u.password_hash, u.last_active,
u.preferred_language, u.current_level, u.ai_provider, u.ai_model,
u.ai_enabled, u.ai_api_key, u.created_at, u.updated_at
FROM users u
JOIN user_questions uq ON u.id = uq.user_id
WHERE uq.question_id = $1
ORDER BY u.username
LIMIT 5
`
	rows, err := s.db.QueryContext(ctx, previewSQL, questionID)
	if err != nil {
		return nil, 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get users for question: %w", err)
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	preview := make([]*models.User, 0)
	for rows.Next() {
		u := new(models.User)
		if err = rows.Scan(
			&u.ID, &u.Username, &u.Email, &u.Timezone, &u.PasswordHash, &u.LastActive,
			&u.PreferredLanguage, &u.CurrentLevel, &u.AIProvider, &u.AIModel,
			&u.AIEnabled, &u.AIAPIKey, &u.CreatedAt, &u.UpdatedAt,
		); err != nil {
			return nil, 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to scan user: %w", err)
		}
		preview = append(preview, u)
	}
	if err = rows.Err(); err != nil {
		return nil, 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "error iterating users: %w", err)
	}
	return preview, assigned, nil
}
// Helper: scan a *sql.Row into a QuestionWithStats (for single-row queries)
func (s *QuestionService) scanQuestionWithPriorityAndStatsFromRow(row *sql.Row) (result0 *QuestionWithStats, err error) {
questionWithStats := &QuestionWithStats{
Question: &models.Question{},
}
var contentJSON string
var priorityScore float64
var timesAnswered int
var lastAnsweredAt sql.NullTime
err = row.Scan(
&questionWithStats.ID,
&questionWithStats.Type,
&questionWithStats.Language,
&questionWithStats.Level,
&questionWithStats.DifficultyScore,
&contentJSON,
&questionWithStats.CorrectAnswer,
&questionWithStats.Explanation,
&questionWithStats.CreatedAt,
&questionWithStats.Status,
&questionWithStats.TopicCategory,
&questionWithStats.GrammarFocus,
&questionWithStats.VocabularyDomain,
&questionWithStats.Scenario,
&questionWithStats.StyleModifier,
&questionWithStats.DifficultyModifier,
&questionWithStats.TimeContext,
&priorityScore,
×Answered,
&lastAnsweredAt,
&questionWithStats.CorrectCount,
&questionWithStats.IncorrectCount,
&questionWithStats.TotalResponses,
)
if err != nil {
return nil, err
}
if err := questionWithStats.UnmarshalContentFromJSON(contentJSON); err != nil {
return nil, err
}
return questionWithStats, nil
}
// GetRandomGlobalQuestionForUser picks one random active question that:
//   - matches language, level and qType;
//   - is not yet in the user's user_questions set;
//   - has not been marked as known by the user with confidence 5 within the
//     last 60 days.
//
// The chosen question is then assigned to the user.
//
// It returns (nil, nil) when no such question exists. If the assignment fails,
// the failure is logged and the question is still returned. The stats fields
// on the result are placeholders: priority 100.0 and zero counts.
func (s *QuestionService) GetRandomGlobalQuestionForUser(ctx context.Context, userID int, language, level string, qType models.QuestionType) (result0 *QuestionWithStats, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_random_global_question_for_user", observability.AttributeUserID(userID), observability.AttributeLanguage(language), observability.AttributeLevel(level), observability.AttributeQuestionType(qType))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const pickSQL = `
SELECT q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status,
q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context,
100.0 as priority_score, 0 as times_answered, NULL as last_answered_at, 0 as correct_count, 0 as incorrect_count, 0 as total_responses
FROM questions q
WHERE q.language = $1
AND q.level = $2
AND q.type = $3
AND q.status = 'active'
AND q.id NOT IN (
SELECT uq.question_id
FROM user_questions uq
WHERE uq.user_id = $4
)
-- Exclude questions the user marked as known with confidence 5 within last 60 days
AND NOT EXISTS (
SELECT 1 FROM user_question_metadata uqm2
WHERE uqm2.user_id = $4
AND uqm2.question_id = q.id
AND uqm2.marked_as_known = TRUE
AND uqm2.confidence_level = 5
AND uqm2.marked_as_known_at >= NOW() - INTERVAL '60 days'
)
ORDER BY RANDOM()
LIMIT 1
`
	picked, err := s.scanQuestionWithPriorityAndStatsFromRow(s.db.QueryRowContext(ctx, pickSQL, language, level, qType, userID))
	switch {
	case errors.Is(err, sql.ErrNoRows):
		// The global pool has nothing left for this user.
		return nil, nil
	case err != nil:
		return nil, err
	}

	// Assignment is best effort: the caller still gets a usable question.
	if assignErr := s.AssignQuestionToUser(ctx, picked.ID, userID); assignErr != nil {
		s.logger.Warn(ctx, "Failed to assign global question to user", map[string]interface{}{"question_id": picked.ID, "user_id": userID, "error": assignErr.Error()})
	}
	return picked, nil
}
// GetAllQuestionsPaginated returns one page of questions, newest first
// (created_at DESC), plus the total number of questions matching the filters.
//
// Every filter is optional; an empty string (or nil userID) disables it:
//   - search: case-insensitive ILIKE match on content (cast to text) or
//     explanation;
//   - typeFilter, statusFilter, languageFilter, levelFilter: exact matches;
//   - userID: only questions assigned to that user through user_questions.
//
// page is 1-based. As in GetReportedQuestionsPaginated, page < 1 is treated as
// 1 and pageSize < 1 as 10, so the OFFSET sent to the database is never
// negative.
//
// Each result carries response counts across all users and the number of
// users the question is assigned to. Driver errors raised while iterating
// rows are returned rather than yielding a silently short page.
func (s *QuestionService) GetAllQuestionsPaginated(ctx context.Context, page, pageSize int, search, typeFilter, statusFilter, languageFilter, levelFilter string, userID *int) (result0 []*QuestionWithStats, result1 int, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_all_questions_paginated")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()
	// Validate pagination parameters (a negative OFFSET is rejected by the database).
	if page < 1 {
		page = 1
	}
	if pageSize < 1 {
		pageSize = 10
	}
	// Build the base query
	baseQuery := `
SELECT q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status,
q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context,
COALESCE(ur_stats.correct_count, 0) as correct_count,
COALESCE(ur_stats.incorrect_count, 0) as incorrect_count,
COALESCE(ur_stats.total_responses, 0) as total_responses,
COALESCE(uq_stats.user_count, 0) as user_count
FROM questions q
LEFT JOIN (
SELECT
question_id,
COUNT(CASE WHEN is_correct = true THEN 1 END) as correct_count,
COUNT(CASE WHEN is_correct = false THEN 1 END) as incorrect_count,
COUNT(*) as total_responses
FROM user_responses
GROUP BY question_id
) ur_stats ON q.id = ur_stats.question_id
LEFT JOIN (
SELECT
question_id,
COUNT(*) as user_count
FROM user_questions
GROUP BY question_id
) uq_stats ON q.id = uq_stats.question_id
WHERE 1=1
`
	// Build the count query; it shares every filter with the base query.
	countQuery := `
SELECT COUNT(*)
FROM questions q
WHERE 1=1
`
	var args []interface{}
	argIndex := 1
	// Add filters. Every value is bound as a positional parameter; only the
	// placeholder numbers are spliced into the SQL text.
	if search != "" {
		searchCondition := ` AND (q.content::text ILIKE $` + strconv.Itoa(argIndex) + ` OR q.explanation ILIKE $` + strconv.Itoa(argIndex) + `)`
		baseQuery += searchCondition
		countQuery += searchCondition
		args = append(args, "%"+search+"%")
		argIndex++
	}
	if typeFilter != "" {
		typeCondition := ` AND q.type = $` + strconv.Itoa(argIndex)
		baseQuery += typeCondition
		countQuery += typeCondition
		args = append(args, typeFilter)
		argIndex++
	}
	if statusFilter != "" {
		statusCondition := ` AND q.status = $` + strconv.Itoa(argIndex)
		baseQuery += statusCondition
		countQuery += statusCondition
		args = append(args, statusFilter)
		argIndex++
	}
	if languageFilter != "" {
		languageCondition := ` AND q.language = $` + strconv.Itoa(argIndex)
		baseQuery += languageCondition
		countQuery += languageCondition
		args = append(args, languageFilter)
		argIndex++
	}
	if levelFilter != "" {
		levelCondition := ` AND q.level = $` + strconv.Itoa(argIndex)
		baseQuery += levelCondition
		countQuery += levelCondition
		args = append(args, levelFilter)
		argIndex++
	}
	if userID != nil {
		userCondition := ` AND q.id IN (SELECT question_id FROM user_questions WHERE user_id = $` + strconv.Itoa(argIndex) + `)`
		baseQuery += userCondition
		countQuery += userCondition
		args = append(args, *userID)
		argIndex++
	}
	// Get total count (uses only the filter args, before pagination args are appended).
	var total int
	err = s.db.QueryRowContext(ctx, countQuery, args...).Scan(&total)
	if err != nil {
		return nil, 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get total count: %w", err)
	}
	// Add pagination
	offset := (page - 1) * pageSize
	baseQuery += ` ORDER BY q.created_at DESC LIMIT $` + strconv.Itoa(argIndex) + ` OFFSET $` + strconv.Itoa(argIndex+1)
	args = append(args, pageSize, offset)
	// Execute the main query
	rows, err := s.db.QueryContext(ctx, baseQuery, args...)
	if err != nil {
		return nil, 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get questions: %w", err)
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Warning: failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()
	var questions []*QuestionWithStats
	for rows.Next() {
		question, scanErr := s.scanQuestionWithStatsAndAllFieldsFromRows(rows)
		if scanErr != nil {
			return nil, 0, scanErr
		}
		questions = append(questions, question)
	}
	// Surface errors that terminated iteration early.
	if err = rows.Err(); err != nil {
		return nil, 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to iterate questions: %w", err)
	}
	return questions, total, nil
}
// GetReportedQuestionsPaginated returns one page of questions whose status is
// 'reported', newest first, plus the total number of matching questions.
//
// Optional filters (an empty string disables each):
//   - search: case-insensitive ILIKE on content (as text) or explanation;
//   - typeFilter, languageFilter, levelFilter: exact matches.
//
// page < 1 is treated as 1 and pageSize < 1 as 10.
//
// Each result carries response counts across all users, plus two aggregates
// from question_reports:
//   - reporters: distinct reporter usernames joined with ", ";
//   - report_reasons: distinct reasons joined with " | ".
//
// Driver errors raised while iterating rows are returned.
func (s *QuestionService) GetReportedQuestionsPaginated(ctx context.Context, page, pageSize int, search, typeFilter, languageFilter, levelFilter string) (result0 []*QuestionWithStats, result1 int, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_reported_questions_paginated")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()
	// Validate pagination parameters
	if page < 1 {
		page = 1
	}
	if pageSize < 1 {
		pageSize = 10
	}
	// Build WHERE clause with filters using parameterized queries
	whereConditions := []string{"q.status = 'reported'"}
	args := []interface{}{}
	argCount := 0
	// Add search filter (the same placeholder is reused for both columns)
	if search != "" {
		argCount++
		whereConditions = append(whereConditions, fmt.Sprintf("(q.content::text ILIKE $%d OR q.explanation ILIKE $%d)", argCount, argCount))
		args = append(args, "%"+search+"%")
	}
	// Add type filter
	if typeFilter != "" {
		argCount++
		whereConditions = append(whereConditions, fmt.Sprintf("q.type = $%d", argCount))
		args = append(args, typeFilter)
	}
	// Add language filter
	if languageFilter != "" {
		argCount++
		whereConditions = append(whereConditions, fmt.Sprintf("q.language = $%d", argCount))
		args = append(args, languageFilter)
	}
	// Add level filter
	if levelFilter != "" {
		argCount++
		whereConditions = append(whereConditions, fmt.Sprintf("q.level = $%d", argCount))
		args = append(args, levelFilter)
	}
	// Join all conditions
	whereClause := "WHERE " + strings.Join(whereConditions, " AND ")
	// Build the count query
	countQuery := fmt.Sprintf("SELECT COUNT(DISTINCT q.id) FROM questions q %s", whereClause)
	var total int
	err = s.db.QueryRowContext(ctx, countQuery, args...).Scan(&total)
	if err != nil {
		return nil, 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get total count: %w", err)
	}
	// Calculate offset
	offset := (page - 1) * pageSize
	// Build main query with pagination; LIMIT/OFFSET take the next two placeholders.
	query := fmt.Sprintf(`
SELECT q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status,
q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context,
COALESCE(ur_stats.correct_count, 0) as correct_count,
COALESCE(ur_stats.incorrect_count, 0) as incorrect_count,
COALESCE(ur_stats.total_responses, 0) as total_responses,
STRING_AGG(DISTINCT u.username, ', ') as reporters,
STRING_AGG(DISTINCT qr.report_reason, ' | ') as report_reasons
FROM questions q
LEFT JOIN (
SELECT
question_id,
COUNT(CASE WHEN is_correct = true THEN 1 END) as correct_count,
COUNT(CASE WHEN is_correct = false THEN 1 END) as incorrect_count,
COUNT(*) as total_responses
FROM user_responses
GROUP BY question_id
) ur_stats ON q.id = ur_stats.question_id
LEFT JOIN question_reports qr ON q.id = qr.question_id
LEFT JOIN users u ON qr.reported_by_user_id = u.id
%s
GROUP BY q.id, q.type, q.language, q.level, q.difficulty_score, q.content, q.correct_answer, q.explanation, q.created_at, q.status,
q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario, q.style_modifier, q.difficulty_modifier, q.time_context,
ur_stats.correct_count, ur_stats.incorrect_count, ur_stats.total_responses
ORDER BY q.created_at DESC
LIMIT $%d OFFSET $%d
`, whereClause, argCount+1, argCount+2)
	// Add pagination parameters
	args = append(args, pageSize, offset)
	// Execute the main query
	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get reported questions: %w", err)
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Warning: failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()
	var questions []*QuestionWithStats
	for rows.Next() {
		question, scanErr := s.scanQuestionWithStatsAndReportersFromRows(rows)
		if scanErr != nil {
			return nil, 0, scanErr
		}
		questions = append(questions, question)
	}
	// Surface errors that terminated iteration early.
	if err = rows.Err(); err != nil {
		return nil, 0, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to iterate reported questions: %w", err)
	}
	return questions, total, nil
}
// GetReportedQuestionsStats returns aggregate statistics about questions whose
// status is 'reported'. The returned map contains:
//   - "total_reported": int count of reported questions;
//   - "reported_by_type": map[string]int keyed by question type;
//   - "reported_by_level": map[string]int keyed by level;
//   - "reported_by_language": map[string]int keyed by language.
//
// Each breakdown is read by reportedQuestionCounts, which owns (and closes) its
// own result set and checks rows.Err, so a failed or truncated query surfaces
// as an error instead of partial statistics.
func (s *QuestionService) GetReportedQuestionsStats(ctx context.Context) (result0 map[string]interface{}, err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "get_reported_questions_stats")
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()
	stats := make(map[string]interface{})
	// Get total reported questions
	var totalReported int
	err = s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM questions WHERE status = 'reported'`).Scan(&totalReported)
	if err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get total reported questions: %w", err)
	}
	stats["total_reported"] = totalReported
	// Get reported questions by type
	reportedByType, err := s.reportedQuestionCounts(ctx, `
SELECT type, COUNT(*) as count
FROM questions
WHERE status = 'reported'
GROUP BY type
ORDER BY count DESC
`, "type")
	if err != nil {
		return nil, err
	}
	stats["reported_by_type"] = reportedByType
	// Get reported questions by level
	reportedByLevel, err := s.reportedQuestionCounts(ctx, `
SELECT level, COUNT(*) as count
FROM questions
WHERE status = 'reported'
GROUP BY level
ORDER BY count DESC
`, "level")
	if err != nil {
		return nil, err
	}
	stats["reported_by_level"] = reportedByLevel
	// Get reported questions by language
	reportedByLanguage, err := s.reportedQuestionCounts(ctx, `
SELECT language, COUNT(*) as count
FROM questions
WHERE status = 'reported'
GROUP BY language
ORDER BY count DESC
`, "language")
	if err != nil {
		return nil, err
	}
	stats["reported_by_language"] = reportedByLanguage
	return stats, nil
}

// reportedQuestionCounts runs a two-column (key, count) grouping query and
// collects the rows into a map. dimension only names the grouping in error
// messages. The result set is closed before returning.
func (s *QuestionService) reportedQuestionCounts(ctx context.Context, query, dimension string) (map[string]int, error) {
	rows, err := s.db.QueryContext(ctx, query)
	if err != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to get reported questions by %s: %w", dimension, err)
	}
	// Each call owns its rows, so this defer always closes the right result set.
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Warning: failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()
	counts := make(map[string]int)
	for rows.Next() {
		var key string
		var count int
		if scanErr := rows.Scan(&key, &count); scanErr != nil {
			return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to scan reported questions by %s: %w", dimension, scanErr)
		}
		counts[key] = count
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "failed to iterate reported questions by %s: %w", dimension, iterErr)
	}
	return counts, nil
}
// AssignUsersToQuestion links every user in userIDs to questionID inside one
// transaction. Existing links are left untouched (ON CONFLICT DO NOTHING).
// If any insert or the commit fails, the transaction is rolled back; a
// rollback failure is only logged.
func (s *QuestionService) AssignUsersToQuestion(ctx context.Context, questionID int, userIDs []int) (err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "assign_users_to_question", observability.AttributeQuestionID(questionID))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapError(err, "failed to begin transaction")
	}
	// err is the named result, so this observes the value actually returned.
	defer func() {
		if err == nil {
			return
		}
		if rbErr := tx.Rollback(); rbErr != nil {
			s.logger.Warn(ctx, "Failed to rollback transaction", map[string]interface{}{"error": rbErr.Error()})
		}
	}()

	insertLink, err := tx.PrepareContext(ctx, `
INSERT INTO user_questions (user_id, question_id, created_at)
VALUES ($1, $2, NOW())
ON CONFLICT (user_id, question_id) DO NOTHING
`)
	if err != nil {
		return contextutils.WrapError(err, "failed to prepare insert statement")
	}
	defer func() {
		if closeErr := insertLink.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Warning: failed to close statement", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	for _, uid := range userIDs {
		if _, err = insertLink.ExecContext(ctx, uid, questionID); err != nil {
			return contextutils.WrapErrorf(err, "failed to assign user %d to question %d", uid, questionID)
		}
	}

	if err = tx.Commit(); err != nil {
		return contextutils.WrapError(err, "failed to commit transaction")
	}
	return nil
}
// UnassignUsersFromQuestion deletes the user_questions link between questionID
// and each user in userIDs, all inside one transaction. Users that were not
// linked are simply unaffected. If any delete or the commit fails, the
// transaction is rolled back; a rollback failure is only logged.
func (s *QuestionService) UnassignUsersFromQuestion(ctx context.Context, questionID int, userIDs []int) (err error) {
	ctx, span := observability.TraceQuestionFunction(ctx, "unassign_users_from_question", observability.AttributeQuestionID(questionID))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapError(err, "failed to begin transaction")
	}
	// err is the named result, so this observes the value actually returned.
	defer func() {
		if err == nil {
			return
		}
		if rbErr := tx.Rollback(); rbErr != nil {
			s.logger.Warn(ctx, "Failed to rollback transaction", map[string]interface{}{"error": rbErr.Error()})
		}
	}()

	deleteLink, err := tx.PrepareContext(ctx, `
DELETE FROM user_questions
WHERE user_id = $1 AND question_id = $2
`)
	if err != nil {
		return contextutils.WrapError(err, "failed to prepare delete statement")
	}
	defer func() {
		if closeErr := deleteLink.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Warning: failed to close statement", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	for _, uid := range userIDs {
		if _, err = deleteLink.ExecContext(ctx, uid, questionID); err != nil {
			return contextutils.WrapErrorf(err, "failed to unassign user %d from question %d", uid, questionID)
		}
	}

	if err = tx.Commit(); err != nil {
		return contextutils.WrapError(err, "failed to commit transaction")
	}
	return nil
}
// DB exposes the *sql.DB handle this QuestionService was built with, for
// callers that need to run queries the service does not wrap.
func (s *QuestionService) DB() *sql.DB {
	handle := s.db
	return handle
}
// Package services provides business logic services for the quiz application.
package services
import (
	"context"
	"database/sql"
	"sort"
	"time"

	"quizapp/internal/config"
	"quizapp/internal/models"
	"quizapp/internal/observability"
	contextutils "quizapp/internal/utils"

	"go.opentelemetry.io/otel"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/trace"
)
// TestEmailService implements the Mailer interface for testing purposes.
// It never sends email; it logs each would-be send. SendEmail additionally
// records a row in sent_notifications through RecordSentNotification, but only
// when a database handle is configured. SendDailyReminder does not record
// anything itself.
type TestEmailService struct {
	cfg    *config.Config        // application config; SendDailyReminder reads cfg.Server.AppBaseURL
	logger *observability.Logger // destination for the "TEST MODE" log lines and recording failures
	db     *sql.DB               // optional; when nil, notifications are logged but not persisted
}
// NewTestEmailService builds a TestEmailService without a database handle, so
// every would-be email is logged but nothing is persisted.
func NewTestEmailService(cfg *config.Config, logger *observability.Logger) *TestEmailService {
	svc := &TestEmailService{logger: logger}
	svc.cfg = cfg
	return svc
}
// NewTestEmailServiceWithDB builds a TestEmailService that, in addition to
// logging, persists sent notifications through db.
func NewTestEmailServiceWithDB(cfg *config.Config, logger *observability.Logger, db *sql.DB) *TestEmailService {
	svc := &TestEmailService{db: db}
	svc.logger = logger
	svc.cfg = cfg
	return svc
}
// SendDailyReminder sends a daily reminder email to a user (test mode - just logs).
//
// Nothing is delivered. The would-be email is written to the logger at Info
// level. A user without a valid, non-empty email address is skipped with a
// warning, and nil is returned. A nil user returns an error wrapping
// contextutils.ErrInvalidInput instead of panicking while building the span.
// The notification is intentionally NOT recorded in the database here.
func (e *TestEmailService) SendDailyReminder(ctx context.Context, user *models.User) error {
	// Guard before starting the span: its attributes dereference user.
	if user == nil {
		return contextutils.WrapError(contextutils.ErrInvalidInput, "user cannot be nil")
	}
	ctx, span := otel.Tracer("test-email-service").Start(ctx, "SendDailyReminder",
		trace.WithAttributes(
			attribute.Int("user.id", user.ID),
			attribute.String("user.email", user.Email.String),
		),
	)
	defer span.End()
	if !user.Email.Valid || user.Email.String == "" {
		e.logger.Warn(ctx, "User has no email address, skipping daily reminder", map[string]interface{}{
			"user_id": user.ID,
		})
		return nil
	}
	// Generate email data (same as real service) - not used in test mode but kept for consistency
	_ = map[string]interface{}{
		"Username":       user.Username,
		"QuizAppURL":     e.cfg.Server.AppBaseURL,
		"CurrentDate":    time.Now().Format("January 2, 2006"),
		"DailyGoal":      10,
		"StreakDays":     5,
		"TotalQuestions": 150,
		"Level":          "B1",
		"Language":       "Italian",
	}
	// Log the email operation instead of sending. Use the same subject as the
	// real service to avoid confusion, but do NOT record a second entry in the
	// database here. Recording is handled by the caller so there is a single
	// source of truth for sent notifications.
	// NOTE(review): the emoji at the end of the subject below looks mis-encoded;
	// confirm it matches the real service's subject byte-for-byte.
	e.logger.Info(ctx, "TEST MODE: Would send daily reminder email", map[string]interface{}{
		"user_id":   user.ID,
		"email":     user.Email.String,
		"template":  "daily_reminder",
		"subject":   "Time for your daily quiz! ð",
		"test_mode": true,
	})
	return nil
}
// SendEmail pretends to send a generic email in test mode.
//
// The request is logged with the template data's keys only (not their values).
// When the service has a database handle, the send is also recorded as a
// "test_email" notification attributed to user ID 0. A recording failure is
// logged and swallowed, so SendEmail always returns nil.
func (e *TestEmailService) SendEmail(ctx context.Context, to, subject, templateName string, data map[string]interface{}) error {
	spanAttrs := trace.WithAttributes(
		attribute.String("email.to", to),
		attribute.String("email.subject", subject),
		attribute.String("email.template", templateName),
	)
	ctx, span := otel.Tracer("test-email-service").Start(ctx, "SendEmail", spanAttrs)
	defer span.End()

	logFields := map[string]interface{}{
		"to":        to,
		"subject":   subject,
		"template":  templateName,
		"test_mode": true,
		"data_keys": getMapKeys(data),
	}
	e.logger.Info(ctx, "TEST MODE: Would send email", logFields)

	if e.db == nil {
		return nil
	}
	// A generic test email has no associated user, so user ID 0 is recorded.
	// NOTE(review): if sent_notifications.user_id references users(id), this
	// insert will be rejected; confirm the schema.
	if recordErr := e.RecordSentNotification(ctx, 0, "test_email", subject, templateName, "sent", ""); recordErr != nil {
		e.logger.Error(ctx, "Failed to record test notification", recordErr, map[string]interface{}{
			"to":       to,
			"template": templateName,
		})
	}
	return nil
}
// RecordSentNotification inserts one row into sent_notifications, stamping
// sent_at with the current time.
//
// Without a database handle it only logs a warning and returns nil. If the
// insert fails, the error is recorded on the span, logged, and returned wrapped.
// On success an Info line is logged.
func (e *TestEmailService) RecordSentNotification(ctx context.Context, userID int, notificationType, subject, templateName, status, errorMessage string) error {
	ctx, span := otel.Tracer("test-email-service").Start(ctx, "RecordSentNotification",
		trace.WithAttributes(
			attribute.Int("user.id", userID),
			attribute.String("notification.type", notificationType),
			attribute.String("notification.status", status),
		),
	)
	defer span.End()

	if e.db == nil {
		e.logger.Warn(ctx, "No database connection available for recording notification", map[string]interface{}{
			"user_id":           userID,
			"notification_type": notificationType,
		})
		return nil
	}

	const insertSentNotification = `
INSERT INTO sent_notifications (user_id, notification_type, subject, template_name, sent_at, status, error_message)
VALUES ($1, $2, $3, $4, $5, $6, $7)
`
	sentAt := time.Now()
	if _, execErr := e.db.ExecContext(ctx, insertSentNotification, userID, notificationType, subject, templateName, sentAt, status, errorMessage); execErr != nil {
		span.RecordError(execErr)
		e.logger.Error(ctx, "Failed to record sent notification", execErr, map[string]interface{}{
			"user_id":           userID,
			"notification_type": notificationType,
			"status":            status,
		})
		return contextutils.WrapError(execErr, "failed to record sent notification")
	}

	e.logger.Info(ctx, "Recorded sent notification", map[string]interface{}{
		"user_id":           userID,
		"notification_type": notificationType,
		"status":            status,
	})
	return nil
}
// IsEnabled reports whether email functionality is enabled. The test service
// always reports true.
func (e *TestEmailService) IsEnabled() bool {
	const enabled = true
	return enabled
}
// getMapKeys returns the keys of data as a slice sorted in ascending order.
//
// Sorting makes log output deterministic; for example, the "data_keys" field
// logged by SendEmail no longer follows Go's randomized map iteration order.
// A nil or empty map yields an empty, non-nil slice.
func getMapKeys(data map[string]interface{}) []string {
	keys := make([]string, 0, len(data))
	for k := range data {
		keys = append(keys, k)
	}
	sort.Strings(keys)
	return keys
}
//go:build integration
// +build integration
package services
import (
"context"
"database/sql"
"os"
"testing"
"quizapp/internal/config"
"quizapp/internal/database"
"quizapp/internal/observability"
"github.com/stretchr/testify/require"
)
// SharedTestDBSetup provides a clean, isolated database for each integration test
// Uses the optimized CleanupTestDatabase function for consistent cleanup
func SharedTestDBSetup(t *testing.T) *sql.DB {
observabilityLogger := observability.NewLogger(&config.OpenTelemetryConfig{EnableLogging: false})
dbManager := database.NewManager(observabilityLogger)
// Require TEST_DATABASE_URL environment variable to be set
databaseURL := os.Getenv("TEST_DATABASE_URL")
if databaseURL == "" {
t.Fatal("TEST_DATABASE_URL environment variable must be set for integration tests")
}
db, err := dbManager.InitDB(databaseURL)
require.NoError(t, err)
// Use the optimized cleanup function
CleanupTestDatabase(db, t)
return db
}
// cleanupDatabase truncates the application tables, restarts the main ID
// sequences and re-seeds the default worker settings, all in one transaction.
// It is the shared implementation used by both CleanupTestDatabase and
// SharedTestSuite.Cleanup.
//
// Every statement is best-effort: a failure is logged (when logger is non-nil)
// and the remaining statements still run. Each statement executes inside its own
// savepoint because PostgreSQL aborts the whole transaction after any failed
// statement. Without that, one missing table or sequence would make every later
// statement fail and the final COMMIT would discard all of the cleanup.
func cleanupDatabase(db *sql.DB, logger *observability.Logger) {
	ctx := context.Background()
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		if logger != nil {
			logger.Error(ctx, "Failed to begin cleanup transaction", err)
		}
		return
	}
	// After a successful Commit, Rollback is a no-op returning sql.ErrTxDone,
	// so an unconditional deferred rollback safely covers every exit path.
	defer func() { _ = tx.Rollback() }()

	// Fast cleanup with batched operations
	cleanupQueries := []string{
		"TRUNCATE TABLE user_responses CASCADE",
		"TRUNCATE TABLE performance_metrics CASCADE",
		"TRUNCATE TABLE user_question_metadata CASCADE",
		"TRUNCATE TABLE question_priority_scores CASCADE",
		"TRUNCATE TABLE user_learning_preferences CASCADE",
		"TRUNCATE TABLE user_questions CASCADE",
		"TRUNCATE TABLE questions CASCADE",
		"TRUNCATE TABLE worker_status CASCADE",
		"TRUNCATE TABLE worker_settings CASCADE",
		"TRUNCATE TABLE user_api_keys CASCADE",
		"TRUNCATE TABLE user_roles CASCADE",
		"TRUNCATE TABLE question_reports CASCADE",
		"TRUNCATE TABLE notification_errors CASCADE",
		"TRUNCATE TABLE upcoming_notifications CASCADE",
		"TRUNCATE TABLE sent_notifications CASCADE",
		"TRUNCATE TABLE daily_question_assignments CASCADE",
		"TRUNCATE TABLE users CASCADE",
	}
	for _, query := range cleanupQueries {
		if stepErr := execCleanupStepInSavepoint(ctx, tx, query); stepErr != nil && logger != nil {
			logger.Warn(ctx, "Could not execute cleanup query", map[string]interface{}{
				"query": query,
				"error": stepErr.Error(),
			})
		}
	}

	// Reset sequences
	sequenceQueries := []string{
		"ALTER SEQUENCE users_id_seq RESTART WITH 1",
		"ALTER SEQUENCE questions_id_seq RESTART WITH 1",
		"ALTER SEQUENCE user_responses_id_seq RESTART WITH 1",
		"ALTER SEQUENCE performance_metrics_id_seq RESTART WITH 1",
	}
	for _, query := range sequenceQueries {
		if stepErr := execCleanupStepInSavepoint(ctx, tx, query); stepErr != nil && logger != nil {
			logger.Warn(ctx, "Could not reset sequence", map[string]interface{}{
				"query": query,
				"error": stepErr.Error(),
			})
		}
	}

	// Re-insert default worker settings
	const seedWorkerSettings = `
INSERT INTO worker_settings (setting_key, setting_value, created_at, updated_at)
VALUES ('global_pause', 'false', NOW(), NOW())
ON CONFLICT (setting_key) DO NOTHING;
`
	if stepErr := execCleanupStepInSavepoint(ctx, tx, seedWorkerSettings); stepErr != nil && logger != nil {
		logger.Error(ctx, "Failed to insert worker settings", stepErr)
	}

	if commitErr := tx.Commit(); commitErr != nil && logger != nil {
		logger.Error(ctx, "Failed to commit cleanup transaction", commitErr)
	}
}

// execCleanupStepInSavepoint runs query inside a savepoint so that a failure is
// undone on its own and does not abort the enclosing transaction.
func execCleanupStepInSavepoint(ctx context.Context, tx *sql.Tx, query string) error {
	if _, err := tx.ExecContext(ctx, "SAVEPOINT cleanup_step"); err != nil {
		return err
	}
	if _, err := tx.ExecContext(ctx, query); err != nil {
		// Undo just this statement and drop the savepoint so the transaction
		// stays usable. If this recovery itself fails, later steps will report
		// their own errors.
		if _, rbErr := tx.ExecContext(ctx, "ROLLBACK TO SAVEPOINT cleanup_step"); rbErr == nil {
			_, _ = tx.ExecContext(ctx, "RELEASE SAVEPOINT cleanup_step")
		}
		return err
	}
	_, err := tx.ExecContext(ctx, "RELEASE SAVEPOINT cleanup_step")
	return err
}
// CleanupTestDatabase cleans up the database for integration tests
// This function can be used by any integration test that needs to clean up the database
// Optimized to use batched transactions for better performance
func CleanupTestDatabase(db *sql.DB, t *testing.T) {
cleanupDatabase(db, nil)
}
package services
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
"github.com/lib/pq"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"golang.org/x/crypto/bcrypt"
)
// UserServiceInterface is the set of user-management operations the rest of
// the application depends on. Depending on this interface rather than a
// concrete type lets tests substitute a mock.
type UserServiceInterface interface {
	// Account creation.
	CreateUser(ctx context.Context, username, language, level string) (*models.User, error)
	CreateUserWithPassword(ctx context.Context, username, password, language, level string) (*models.User, error)
	CreateUserWithEmailAndTimezone(ctx context.Context, username, email, timezone, language, level string) (*models.User, error)

	// Lookup and authentication.
	GetUserByID(ctx context.Context, id int) (*models.User, error)
	GetUserByUsername(ctx context.Context, username string) (*models.User, error)
	GetUserByEmail(ctx context.Context, email string) (*models.User, error)
	AuthenticateUser(ctx context.Context, username, password string) (*models.User, error)

	// Profile, settings and activity updates.
	UpdateUserSettings(ctx context.Context, userID int, settings *models.UserSettings) error
	UpdateUserProfile(ctx context.Context, userID int, username, email, timezone string) error
	UpdateUserPassword(ctx context.Context, userID int, newPassword string) error
	UpdateLastActive(ctx context.Context, userID int) error

	// Listing.
	GetAllUsers(ctx context.Context) ([]models.User, error)
	GetUsersPaginated(ctx context.Context, page, pageSize int, search, language, level, aiProvider, aiModel, aiEnabled, active string) ([]models.User, int, error)

	// Deletion, bootstrap and data maintenance.
	DeleteUser(ctx context.Context, userID int) error
	DeleteAllUsers(ctx context.Context) error
	EnsureAdminUserExists(ctx context.Context, adminUsername, adminPassword string) error
	ResetDatabase(ctx context.Context) error
	ClearUserData(ctx context.Context) error
	ClearUserDataForUser(ctx context.Context, userID int) error

	// Per-provider AI API keys.
	GetUserAPIKey(ctx context.Context, userID int, provider string) (string, error)
	SetUserAPIKey(ctx context.Context, userID int, provider, apiKey string) error
	HasUserAPIKey(ctx context.Context, userID int, provider string) (bool, error)

	// Role management.
	GetUserRoles(ctx context.Context, userID int) ([]models.Role, error)
	GetAllRoles(ctx context.Context) ([]models.Role, error)
	AssignRole(ctx context.Context, userID, roleID int) error
	AssignRoleByName(ctx context.Context, userID int, roleName string) error
	RemoveRole(ctx context.Context, userID, roleID int) error
	HasRole(ctx context.Context, userID int, roleName string) (bool, error)
	IsAdmin(ctx context.Context, userID int) (bool, error)

	// Raw database access.
	GetDB() *sql.DB
}
// UserService provides methods for user management.
// NOTE(review): presumably the concrete implementation of UserServiceInterface;
// no compile-time assertion is visible here, so confirm.
type UserService struct {
	db     *sql.DB               // handle used for all queries and transactions in this service
	cfg    *config.Config        // application configuration; not read by the methods visible in this chunk
	logger *observability.Logger // receives warnings/errors, e.g. best-effort role-loading failures
}
// Column lists shared by the user SELECT queries. Column order matters: it
// must match the Scan destinations in scanUserFromRow and
// scanUserFromRowsNoPassword respectively.
const (
	// userSelectFields lists every users column read by the single-user
	// lookups (GetUserByID, GetUserByUsername), including password_hash.
	userSelectFields = "id, username, email, timezone, password_hash, last_active, " +
		"preferred_language, current_level, ai_provider, ai_model, ai_enabled, ai_api_key, " +
		"created_at, updated_at"
	// userSelectFieldsNoPassword is userSelectFields without password_hash.
	// The listing queries use it so password hashes are never fetched there.
	userSelectFieldsNoPassword = "id, username, email, timezone, last_active, " +
		"preferred_language, current_level, ai_provider, ai_model, ai_enabled, ai_api_key, " +
		"created_at, updated_at"
)
// scanUserFromRow decodes one row selected with userSelectFields into a new
// models.User. The destination order mirrors userSelectFields. Any Scan error,
// including sql.ErrNoRows, is returned unchanged with a nil user.
func (s *UserService) scanUserFromRow(row *sql.Row) (result0 *models.User, err error) {
	var u models.User
	dest := []interface{}{
		&u.ID, &u.Username, &u.Email, &u.Timezone, &u.PasswordHash, &u.LastActive,
		&u.PreferredLanguage, &u.CurrentLevel, &u.AIProvider,
		&u.AIModel, &u.AIEnabled, &u.AIAPIKey, &u.CreatedAt, &u.UpdatedAt,
	}
	if err = row.Scan(dest...); err != nil {
		return nil, err
	}
	return &u, nil
}
// scanUserFromRowsNoPassword decodes the current row of rows, selected with
// userSelectFieldsNoPassword, into a new models.User. PasswordHash is left at
// its zero value. Scan errors are returned unchanged with a nil user.
func (s *UserService) scanUserFromRowsNoPassword(rows *sql.Rows) (result0 *models.User, err error) {
	var u models.User
	dest := []interface{}{
		&u.ID, &u.Username, &u.Email, &u.Timezone, &u.LastActive,
		&u.PreferredLanguage, &u.CurrentLevel, &u.AIProvider,
		&u.AIModel, &u.AIEnabled, &u.AIAPIKey, &u.CreatedAt, &u.UpdatedAt,
	}
	if err = rows.Scan(dest...); err != nil {
		return nil, err
	}
	return &u, nil
}
// getUserByQuery runs a single-row query whose columns match userSelectFields
// and decodes the result.
//
// A missing row is reported as (nil, nil), not an error. On success
// applyDefaultSettings is invoked on the user; its outcome is not checked.
func (s *UserService) getUserByQuery(ctx context.Context, query string, args ...interface{}) (result0 *models.User, err error) {
	user, scanErr := s.scanUserFromRow(s.db.QueryRowContext(ctx, query, args...))
	switch {
	case errors.Is(scanErr, sql.ErrNoRows):
		return nil, nil // User not found is not an error here
	case scanErr != nil:
		return nil, scanErr
	}
	s.applyDefaultSettings(ctx, user)
	return user, nil
}
// NewUserServiceWithLogger returns a UserService that runs its queries on db,
// keeps cfg, and reports warnings and errors through logger.
func NewUserServiceWithLogger(db *sql.DB, cfg *config.Config, logger *observability.Logger) *UserService {
	svc := new(UserService)
	svc.db = db
	svc.cfg = cfg
	svc.logger = logger
	return svc
}
// CreateUser creates a new user with the specified username, language and level,
// then reloads it (including roles) through GetUserByID.
//
// New users get timezone "UTC", and last_active, created_at and updated_at set
// to the current time.
//
// Errors:
//   - ErrInvalidInput (wrapped) when username is empty or only whitespace.
//   - ErrRecordExists when isDuplicateKeyError classifies the insert failure as
//     a duplicate key, consistent with CreateUserWithPassword and
//     CreateUserWithEmailAndTimezone.
//   - ErrDatabaseQuery (wrapped) when the new row cannot be read back.
func (s *UserService) CreateUser(ctx context.Context, username, language, level string) (result0 *models.User, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "create_user", attribute.String("user.username", username))
	defer observability.FinishSpan(span, &err)
	// TrimSpace also covers the empty string.
	if strings.TrimSpace(username) == "" {
		return nil, contextutils.WrapError(contextutils.ErrInvalidInput, "username cannot be empty")
	}
	// default timezone to UTC for new users
	query := `INSERT INTO users (username, preferred_language, current_level, last_active, created_at, updated_at, timezone) VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING id`
	now := time.Now()
	var id int
	err = s.db.QueryRowContext(ctx, query, username, language, level, now, now, now, "UTC").Scan(&id)
	if err != nil {
		if isDuplicateKeyError(err) {
			return nil, contextutils.ErrRecordExists
		}
		return nil, err
	}
	var user *models.User
	user, err = s.GetUserByID(ctx, id)
	if err != nil {
		return nil, err
	}
	if user == nil {
		return nil, contextutils.WrapError(contextutils.ErrDatabaseQuery, "user was created but could not be retrieved from database")
	}
	return user, nil
}
// CreateUserWithEmailAndTimezone creates a new user with email and timezone,
// then reloads it (including roles) through GetUserByID.
//
// AI is disabled for the new user, and last_active, created_at and updated_at
// are set to the current time.
//
// Errors:
//   - ErrInvalidInput (wrapped) when username is empty or only whitespace.
//   - ErrRecordExists when isDuplicateKeyError classifies the insert failure as
//     a duplicate key.
//   - ErrDatabaseQuery (wrapped) when the new row cannot be read back.
func (s *UserService) CreateUserWithEmailAndTimezone(ctx context.Context, username, email, timezone, language, level string) (result0 *models.User, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "create_user_with_email", attribute.String("user.username", username))
	defer observability.FinishSpan(span, &err)
	// TrimSpace also covers the empty string.
	if strings.TrimSpace(username) == "" {
		return nil, contextutils.WrapError(contextutils.ErrInvalidInput, "username cannot be empty")
	}
	query := `INSERT INTO users (username, email, timezone, preferred_language, current_level, ai_enabled, last_active, created_at, updated_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id`
	now := time.Now()
	var id int
	err = s.db.QueryRowContext(ctx, query, username, email, timezone, language, level, false, now, now, now).Scan(&id)
	if err != nil {
		if isDuplicateKeyError(err) {
			return nil, contextutils.ErrRecordExists
		}
		return nil, err
	}
	var user *models.User
	user, err = s.GetUserByID(ctx, id)
	if err != nil {
		return nil, err
	}
	if user == nil {
		return nil, contextutils.WrapError(contextutils.ErrDatabaseQuery, "user was created but could not be retrieved from database")
	}
	return user, nil
}
// CreateUserWithPassword creates a new user with password authentication, then
// reloads it (including roles) through GetUserByID.
//
// The password is stored as a bcrypt hash (bcrypt.DefaultCost). AI is disabled,
// the timezone defaults to "UTC", and last_active, created_at and updated_at
// are set to the current time.
//
// Errors:
//   - ErrInvalidInput (wrapped) when username is empty or only whitespace.
//   - The bcrypt error if hashing fails.
//   - ErrRecordExists when isDuplicateKeyError classifies the insert failure as
//     a duplicate key.
//   - ErrDatabaseQuery (wrapped) when the new row cannot be read back.
func (s *UserService) CreateUserWithPassword(ctx context.Context, username, password, language, level string) (result0 *models.User, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "create_user_with_password", attribute.String("user.username", username))
	defer observability.FinishSpan(span, &err)
	// TrimSpace also covers the empty string.
	if strings.TrimSpace(username) == "" {
		return nil, contextutils.WrapError(contextutils.ErrInvalidInput, "username cannot be empty")
	}
	// Hash the password using bcrypt
	var hashedPassword []byte
	hashedPassword, err = bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
	if err != nil {
		return nil, err
	}
	// default timezone to UTC for new users created with password
	query := `INSERT INTO users (username, password_hash, preferred_language, current_level, ai_enabled, last_active, created_at, updated_at, timezone) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING id`
	now := time.Now()
	var id int
	err = s.db.QueryRowContext(ctx, query, username, string(hashedPassword), language, level, false, now, now, now, "UTC").Scan(&id)
	if err != nil {
		if isDuplicateKeyError(err) {
			return nil, contextutils.ErrRecordExists
		}
		return nil, err
	}
	var user *models.User
	user, err = s.GetUserByID(ctx, id)
	if err != nil {
		return nil, err
	}
	if user == nil {
		return nil, contextutils.WrapError(contextutils.ErrDatabaseQuery, "user was created but could not be retrieved from database")
	}
	return user, nil
}
// AuthenticateUser looks the user up by username and checks password against
// the stored bcrypt hash, returning the user on success.
//
// It fails with "user not found" for an unknown username, "user has no
// password set" when no hash is stored, "invalid password" when the bcrypt
// comparison fails, or the lookup error itself.
// NOTE(review): these distinct messages, plus skipping bcrypt for unknown
// users, let callers tell valid usernames apart. Ensure handlers map all of
// them to one generic response.
func (s *UserService) AuthenticateUser(ctx context.Context, username, password string) (result0 *models.User, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "authenticate_user", attribute.String("user.username", username))
	defer observability.FinishSpan(span, &err)

	user, lookupErr := s.GetUserByUsername(ctx, username)
	switch {
	case lookupErr != nil:
		return nil, lookupErr
	case user == nil:
		return nil, errors.New("user not found")
	case !user.PasswordHash.Valid:
		return nil, errors.New("user has no password set")
	}

	if bcrypt.CompareHashAndPassword([]byte(user.PasswordHash.String), []byte(password)) != nil {
		return nil, errors.New("invalid password")
	}
	return user, nil
}
// GetUserByID loads the user with the given ID, together with their roles.
//
// A missing user yields (nil, nil) and a debug log line. Database errors are
// logged and returned. Role loading is best-effort: if it fails, a warning is
// logged and Roles is set to an empty slice instead of failing the lookup.
func (s *UserService) GetUserByID(ctx context.Context, id int) (result0 *models.User, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "get_user_by_id", attribute.Int("user.id", id))
	defer observability.FinishSpan(span, &err)

	user, lookupErr := s.getUserByQuery(ctx, fmt.Sprintf("SELECT %s FROM users WHERE id = $1", userSelectFields), id)
	if lookupErr != nil {
		s.logger.Error(ctx, "Database error retrieving user", lookupErr, map[string]interface{}{"user_id": id})
		return nil, lookupErr
	}
	if user == nil {
		s.logger.Debug(ctx, "User not found in database", map[string]interface{}{"user_id": id})
		return nil, nil
	}

	if roles, rolesErr := s.GetUserRoles(ctx, id); rolesErr == nil {
		user.Roles = roles
	} else {
		s.logger.Warn(ctx, "Failed to load user roles", map[string]interface{}{"user_id": id, "error": rolesErr.Error()})
		user.Roles = []models.Role{}
	}
	return user, nil
}
// GetUserByUsername loads the user with the given username, returning
// (nil, nil) when no such user exists.
// NOTE(review): unlike GetUserByID, this does not load the user's roles.
func (s *UserService) GetUserByUsername(ctx context.Context, username string) (result0 *models.User, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "get_user_by_username", attribute.String("user.username", username))
	defer observability.FinishSpan(span, &err)

	byUsername := fmt.Sprintf("SELECT %s FROM users WHERE username = $1", userSelectFields)
	result0, err = s.getUserByQuery(ctx, byUsername, username)
	return result0, err
}
// UpdateUserSettings updates a user's language, level and AI configuration.
// When AI is enabled with a provider and key, it also upserts that provider's
// API key. Everything runs in a single transaction bound to ctx.
//
// When settings.AIEnabled is false, the stored provider and model are cleared.
//
// Errors:
//   - ErrInvalidInput (wrapped) for nil settings.
//   - ErrRecordNotFound (wrapped) when the user does not exist.
//   - Otherwise, wrapped database errors.
func (s *UserService) UpdateUserSettings(ctx context.Context, userID int, settings *models.UserSettings) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "update_user_settings", attribute.Int("user.id", userID))
	defer observability.FinishSpan(span, &err)
	if settings == nil {
		return contextutils.WrapError(contextutils.ErrInvalidInput, "settings cannot be nil")
	}
	// Check if user exists before updating settings
	user, err := s.GetUserByID(ctx, userID)
	if err != nil {
		return contextutils.WrapError(err, "failed to check if user exists")
	}
	if user == nil {
		return contextutils.WrapError(contextutils.ErrRecordNotFound, "user not found")
	}
	// Bind the transaction to ctx so cancellation and deadlines apply to it.
	var tx *sql.Tx
	tx, err = s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapError(err, "failed to begin transaction for user settings update")
	}
	// Always attempt a rollback. After a successful Commit it returns
	// sql.ErrTxDone, which is expected and not logged.
	defer func() {
		if rollbackErr := tx.Rollback(); rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
			s.logger.Warn(ctx, "Warning: failed to rollback transaction", map[string]interface{}{"error": rollbackErr.Error()})
		}
	}()
	// If AI is disabled, clear the provider and model
	aiProvider := settings.AIProvider
	aiModel := settings.AIModel
	if !settings.AIEnabled {
		aiProvider = ""
		aiModel = ""
	}
	// Update user settings (excluding API key which is now stored separately)
	query := `UPDATE users SET preferred_language = $1, current_level = $2, ai_provider = $3, ai_model = $4, ai_enabled = $5, updated_at = $6 WHERE id = $7`
	var result sql.Result
	result, err = tx.ExecContext(ctx, query, settings.Language, settings.Level, aiProvider, aiModel, settings.AIEnabled, time.Now(), userID)
	if err != nil {
		return contextutils.WrapError(err, "failed to update user settings in transaction")
	}
	// Check if the user was actually updated
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return contextutils.WrapError(err, "failed to get rows affected")
	}
	if rowsAffected == 0 {
		return contextutils.WrapErrorf(contextutils.ErrRecordNotFound, "user with ID %d not found", userID)
	}
	// If an API key is provided and AI is enabled, save it for the specific provider
	if settings.AIAPIKey != "" && settings.AIProvider != "" && settings.AIEnabled {
		err = s.setUserAPIKeyTx(ctx, tx, userID, settings.AIProvider, settings.AIAPIKey)
		if err != nil {
			return contextutils.WrapError(err, "failed to set user API key in transaction")
		}
	}
	if err = tx.Commit(); err != nil {
		return contextutils.WrapError(err, "failed to commit user settings transaction")
	}
	return nil
}
// GetUserAPIKey returns the API key stored for userID and provider.
//
// Errors:
//   - ErrRecordNotFound wrapped with the text "user not found" when the user
//     does not exist.
//   - ErrRecordNotFound wrapped with the text "API key for provider not found"
//     when no key is stored. HasUserAPIKey matches on this text.
//   - Otherwise, a wrapped database error.
func (s *UserService) GetUserAPIKey(ctx context.Context, userID int, provider string) (result0 string, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "get_user_api_key", attribute.Int("user.id", userID), attribute.String("user.provider", provider))
	defer observability.FinishSpan(span, &err)

	owner, lookupErr := s.GetUserByID(ctx, userID)
	if lookupErr != nil {
		return "", contextutils.WrapError(lookupErr, "failed to check if user exists")
	}
	if owner == nil {
		return "", contextutils.WrapError(contextutils.ErrRecordNotFound, "user not found")
	}

	const selectKey = `SELECT api_key FROM user_api_keys WHERE user_id = $1 AND provider = $2`
	var apiKey string
	scanErr := s.db.QueryRowContext(ctx, selectKey, userID, provider).Scan(&apiKey)
	switch {
	case errors.Is(scanErr, sql.ErrNoRows):
		return "", contextutils.WrapError(contextutils.ErrRecordNotFound, "API key for provider not found")
	case scanErr != nil:
		return "", contextutils.WrapError(scanErr, "failed to get user API key")
	}
	return apiKey, nil
}
// SetUserAPIKey stores (inserts or replaces) the API key for provider on
// behalf of userID, inside a transaction bound to ctx.
//
// Returns a wrapped ErrRecordNotFound when the user does not exist; otherwise
// returns wrapped database errors.
func (s *UserService) SetUserAPIKey(ctx context.Context, userID int, provider, apiKey string) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "set_user_api_key", attribute.Int("user.id", userID), attribute.String("user.provider", provider))
	defer observability.FinishSpan(span, &err)
	// Check if user exists before setting API key
	user, err := s.GetUserByID(ctx, userID)
	if err != nil {
		return contextutils.WrapError(err, "failed to check if user exists")
	}
	if user == nil {
		return contextutils.WrapError(contextutils.ErrRecordNotFound, "user not found")
	}
	// Bind the transaction to ctx so cancellation and deadlines apply to it.
	var tx *sql.Tx
	tx, err = s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapError(err, "failed to begin transaction for API key update")
	}
	// Always attempt a rollback. After a Commit (successful or not) it returns
	// sql.ErrTxDone, which is expected and not logged.
	defer func() {
		if rollbackErr := tx.Rollback(); rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
			s.logger.Warn(ctx, "Warning: failed to rollback transaction", map[string]interface{}{"error": rollbackErr.Error()})
		}
	}()
	if err = s.setUserAPIKeyTx(ctx, tx, userID, provider, apiKey); err != nil {
		return contextutils.WrapError(err, "failed to set user API key in transaction")
	}
	if err = tx.Commit(); err != nil {
		return contextutils.WrapError(err, "failed to commit API key transaction")
	}
	return nil
}
// setUserAPIKeyTx upserts the API key for (userID, provider) within tx and
// refreshes updated_at.
//
// It returns nil on success and a wrapped error otherwise. The explicit nil
// check avoids relying on contextutils.WrapError passing a nil error through
// unchanged.
func (s *UserService) setUserAPIKeyTx(ctx context.Context, tx *sql.Tx, userID int, provider, apiKey string) error {
	query := `INSERT INTO user_api_keys (user_id, provider, api_key, updated_at)
VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, provider)
DO UPDATE SET api_key = $3, updated_at = $4`
	if _, err := tx.ExecContext(ctx, query, userID, provider, apiKey, time.Now()); err != nil {
		return contextutils.WrapError(err, "failed to execute API key transaction")
	}
	return nil
}
// HasUserAPIKey reports whether userID has a non-empty API key stored for
// provider.
//
// A missing key yields (false, nil). A missing user, or any other not-found
// error, is returned as-is. Other failures are wrapped.
// NOTE(review): the key-vs-user distinction relies on matching the error text
// produced by GetUserAPIKey.
func (s *UserService) HasUserAPIKey(ctx context.Context, userID int, provider string) (result0 bool, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "has_user_api_key", attribute.Int("user.id", userID), attribute.String("user.provider", provider))
	defer observability.FinishSpan(span, &err)

	apiKey, lookupErr := s.GetUserAPIKey(ctx, userID, provider)
	if lookupErr == nil {
		return apiKey != "", nil
	}
	if !errors.Is(lookupErr, contextutils.ErrRecordNotFound) {
		return false, contextutils.WrapError(lookupErr, "failed to check if user has API key")
	}
	if strings.Contains(lookupErr.Error(), "API key for provider not found") {
		return false, nil
	}
	return false, lookupErr
}
// UpdateLastActive sets the user's last_active column to the current time.
// It returns a wrapped ErrRecordNotFound when no row with userID exists, and
// wrapped database errors otherwise.
func (s *UserService) UpdateLastActive(ctx context.Context, userID int) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "update_last_active", attribute.Int("user.id", userID))
	defer observability.FinishSpan(span, &err)

	const touch = `UPDATE users SET last_active = $1 WHERE id = $2`
	res, execErr := s.db.ExecContext(ctx, touch, time.Now(), userID)
	if execErr != nil {
		return contextutils.WrapError(execErr, "failed to update user last active timestamp")
	}

	affected, countErr := res.RowsAffected()
	switch {
	case countErr != nil:
		return contextutils.WrapError(countErr, "failed to get rows affected")
	case affected == 0:
		return contextutils.WrapErrorf(contextutils.ErrRecordNotFound, "user with ID %d not found", userID)
	}
	return nil
}
// GetAllUsers returns every user (without password hashes), each with their
// roles loaded best-effort.
//
// If loading a user's roles fails, a warning is logged and that user's Roles
// is set to an empty slice. Query, scan and row-iteration errors are returned
// wrapped. The result is nil when there are no users.
// NOTE(review): GetUserRoles is called while rows is still open, so each call
// presumably needs a second pooled connection. Confirm the pool size allows it.
func (s *UserService) GetAllUsers(ctx context.Context) (result0 []models.User, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "get_all_users")
	defer observability.FinishSpan(span, &err)
	query := fmt.Sprintf("SELECT %s FROM users", userSelectFieldsNoPassword)
	var rows *sql.Rows
	rows, err = s.db.QueryContext(ctx, query)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to query all users")
	}
	defer func() {
		// Use a separate variable. Assigning Close's result to the named err
		// would overwrite, and silently drop, any error being returned.
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Warning: failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()
	var users []models.User
	for rows.Next() {
		user, scanErr := s.scanUserFromRowsNoPassword(rows)
		if scanErr != nil {
			return nil, contextutils.WrapError(scanErr, "failed to scan user from rows")
		}
		// Load user roles; don't fail the entire request if they can't be loaded.
		roles, rolesErr := s.GetUserRoles(ctx, user.ID)
		if rolesErr != nil {
			s.logger.Warn(ctx, "Failed to load user roles", map[string]interface{}{"user_id": user.ID, "error": rolesErr.Error()})
			user.Roles = []models.Role{}
		} else {
			user.Roles = roles
		}
		users = append(users, *user)
	}
	// rows.Next returns false both at the end and on failure; tell them apart.
	if err = rows.Err(); err != nil {
		return nil, contextutils.WrapError(err, "failed to iterate user rows")
	}
	return users, nil
}
// GetUsersPaginated returns one page of users (selected with userSelectFieldsNoPassword and
// ordered by username) that match every non-empty filter, plus the total number of matching
// users across all pages.
//
// Filters:
//   - search: case-insensitive substring match (ILIKE) on username or email.
//   - language, level, aiProvider, aiModel: exact match on the corresponding column.
//   - aiEnabled: "true" filters ai_enabled = true; any other non-empty value filters false.
//   - active: "true" keeps users whose last_active is within the last 7 days, "false" keeps
//     the rest (including NULL last_active); any other value adds no condition.
//
// page is 1-based; a page below 1 is treated as 1 so the OFFSET is never negative. pageSize
// is passed to LIMIT unchanged. Roles are loaded for each user; if that lookup fails the
// failure is logged and the user's Roles is set to an empty slice.
func (s *UserService) GetUsersPaginated(ctx context.Context, page, pageSize int, search, language, level, aiProvider, aiModel, aiEnabled, active string) (result0 []models.User, result1 int, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "get_users_paginated")
	defer observability.FinishSpan(span, &err)

	var conditions []string
	var args []interface{}
	// bind records value as the next query argument and returns its "$N" placeholder, so the
	// placeholder numbers always match the argument list exactly (no gaps).
	bind := func(value interface{}) string {
		args = append(args, value)
		return fmt.Sprintf("$%d", len(args))
	}

	if search != "" {
		p := bind("%" + search + "%")
		conditions = append(conditions, fmt.Sprintf("(username ILIKE %s OR email ILIKE %s)", p, p))
	}
	if language != "" {
		conditions = append(conditions, "preferred_language = "+bind(language))
	}
	if level != "" {
		conditions = append(conditions, "current_level = "+bind(level))
	}
	if aiProvider != "" {
		conditions = append(conditions, "ai_provider = "+bind(aiProvider))
	}
	if aiModel != "" {
		conditions = append(conditions, "ai_model = "+bind(aiModel))
	}
	if aiEnabled != "" {
		conditions = append(conditions, "ai_enabled = "+bind(aiEnabled == "true"))
	}
	// Activity is judged against a 7-day window. Unrecognised values add neither a condition
	// nor a placeholder.
	activeThreshold := time.Now().AddDate(0, 0, -7)
	switch active {
	case "true":
		conditions = append(conditions, "last_active >= "+bind(activeThreshold))
	case "false":
		conditions = append(conditions, fmt.Sprintf("(last_active < %s OR last_active IS NULL)", bind(activeThreshold)))
	}

	whereClause := ""
	if len(conditions) > 0 {
		whereClause = "WHERE " + strings.Join(conditions, " AND ")
	}

	// The count uses only the filter arguments; LIMIT/OFFSET are bound afterwards.
	var total int
	countQuery := fmt.Sprintf("SELECT COUNT(*) FROM users %s", whereClause)
	if err = s.db.QueryRowContext(ctx, countQuery, args...).Scan(&total); err != nil {
		return nil, 0, contextutils.WrapError(err, "failed to count users")
	}

	if page < 1 {
		page = 1
	}
	limitPlaceholder := bind(pageSize)
	offsetPlaceholder := bind((page - 1) * pageSize)
	query := fmt.Sprintf("SELECT %s FROM users %s ORDER BY username LIMIT %s OFFSET %s",
		userSelectFieldsNoPassword, whereClause, limitPlaceholder, offsetPlaceholder)

	rows, err := s.db.QueryContext(ctx, query, args...)
	if err != nil {
		return nil, 0, contextutils.WrapError(err, "failed to query paginated users")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Warning: failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var users []models.User
	for rows.Next() {
		user, scanErr := s.scanUserFromRowsNoPassword(rows)
		if scanErr != nil {
			return nil, 0, contextutils.WrapError(scanErr, "failed to scan user from rows")
		}
		roles, rolesErr := s.GetUserRoles(ctx, user.ID)
		if rolesErr != nil {
			s.logger.Warn(ctx, "Failed to load user roles", map[string]interface{}{"user_id": user.ID, "error": rolesErr.Error()})
			// Don't fail the entire request if roles can't be loaded.
			roles = []models.Role{}
		}
		user.Roles = roles
		users = append(users, *user)
	}
	// rows.Next also returns false when iteration fails; report that instead of a short page.
	if err = rows.Err(); err != nil {
		return nil, 0, contextutils.WrapError(err, "failed to iterate paginated users")
	}
	return users, total, nil
}
// GetUserByEmail looks up the user whose email column equals email exactly, selecting
// userSelectFields. The query execution and row mapping are delegated to getUserByQuery.
func (s *UserService) GetUserByEmail(ctx context.Context, email string) (result0 *models.User, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "get_user_by_email", attribute.String("user.email", email))
	defer observability.FinishSpan(span, &err)

	return s.getUserByQuery(ctx, fmt.Sprintf("SELECT %s FROM users WHERE email = $1", userSelectFields), email)
}
// UpdateUserProfile overwrites the username, email and timezone of the user with ID userID
// and sets updated_at to the current time. It returns an error wrapping
// contextutils.ErrRecordNotFound when no row has that ID.
func (s *UserService) UpdateUserProfile(ctx context.Context, userID int, username, email, timezone string) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "update_user_profile", attribute.Int("user.id", userID))
	defer observability.FinishSpan(span, &err)

	const updateSQL = `UPDATE users SET username = $1, email = $2, timezone = $3, updated_at = $4 WHERE id = $5`
	res, execErr := s.db.ExecContext(ctx, updateSQL, username, email, timezone, time.Now(), userID)
	if execErr != nil {
		return contextutils.WrapError(execErr, "failed to update user profile")
	}

	// Zero affected rows means the ID did not match any user.
	affected, countErr := res.RowsAffected()
	switch {
	case countErr != nil:
		return contextutils.WrapError(countErr, "failed to get rows affected")
	case affected == 0:
		return contextutils.WrapErrorf(contextutils.ErrRecordNotFound, "user with ID %d not found", userID)
	default:
		return nil
	}
}
// UpdateUserPassword stores a bcrypt hash (bcrypt.DefaultCost) of newPassword as the
// password of the user with ID userID and sets updated_at to the current time.
//
// An empty newPassword is rejected. When GetUserByID returns no user, or the UPDATE affects
// no row, the returned error wraps contextutils.ErrRecordNotFound.
func (s *UserService) UpdateUserPassword(ctx context.Context, userID int, newPassword string) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "update_user_password", attribute.Int("user.id", userID))
	defer observability.FinishSpan(span, &err)

	if newPassword == "" {
		return contextutils.ErrorWithContextf("password cannot be empty")
	}

	existing, lookupErr := s.GetUserByID(ctx, userID)
	if lookupErr != nil {
		return contextutils.WrapError(lookupErr, "failed to check if user exists")
	}
	if existing == nil {
		return contextutils.WrapError(contextutils.ErrRecordNotFound, "user not found")
	}

	hash, hashErr := bcrypt.GenerateFromPassword([]byte(newPassword), bcrypt.DefaultCost)
	if hashErr != nil {
		return contextutils.WrapError(hashErr, "failed to hash password")
	}

	const updateSQL = `UPDATE users SET password_hash = $1, updated_at = $2 WHERE id = $3`
	res, execErr := s.db.ExecContext(ctx, updateSQL, string(hash), time.Now(), userID)
	if execErr != nil {
		return contextutils.WrapError(execErr, "failed to update user password")
	}

	affected, countErr := res.RowsAffected()
	if countErr != nil {
		return contextutils.WrapError(countErr, "failed to get rows affected")
	}
	if affected == 0 {
		// The user disappeared between the existence check and the update.
		return contextutils.WrapError(contextutils.ErrRecordNotFound, "user not found")
	}

	s.logger.Info(ctx, "Password updated successfully", map[string]interface{}{"user_id": userID, "username": existing.Username})
	return nil
}
// DeleteUser deletes the user with ID userID from users. Before that, it makes a
// best-effort removal of that user's rows in question_reports, user_api_keys, user_roles,
// user_learning_preferences, question_priority_scores, user_question_metadata,
// user_responses and user_questions.
//
// A failed dependent-row delete is logged and skipped. When the user does not exist (or the
// final DELETE affects no row), the returned error wraps contextutils.ErrRecordNotFound.
//
// NOTE(review): these statements are not run in one transaction, so if the final DELETE
// fails the dependent rows are already gone. Wrapping them in a transaction would need a
// savepoint per cleanup statement, because PostgreSQL aborts a transaction on its first
// failed statement.
func (s *UserService) DeleteUser(ctx context.Context, userID int) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "delete_user", attribute.Int("user.id", userID))
	defer observability.FinishSpan(span, &err)

	// Check if user exists before deleting.
	user, err := s.GetUserByID(ctx, userID)
	if err != nil {
		return contextutils.WrapError(err, "failed to check if user exists")
	}
	if user == nil {
		return contextutils.WrapError(contextutils.ErrRecordNotFound, "user not found")
	}

	// Best-effort cleanup of dependent rows for tables that may not have ON DELETE CASCADE in
	// some environments. This keeps tests deterministic and avoids orphaned data.
	// TODO: This is a hack to make the tests deterministic. We should use ON DELETE CASCADE instead.
	cleanupQueries := []string{
		`DELETE FROM question_reports WHERE reported_by_user_id = $1`,
		`DELETE FROM user_api_keys WHERE user_id = $1`,
		`DELETE FROM user_roles WHERE user_id = $1`,
		`DELETE FROM user_learning_preferences WHERE user_id = $1`,
		`DELETE FROM question_priority_scores WHERE user_id = $1`,
		`DELETE FROM user_question_metadata WHERE user_id = $1`,
		`DELETE FROM user_responses WHERE user_id = $1`,
		`DELETE FROM user_questions WHERE user_id = $1`,
	}
	for _, q := range cleanupQueries {
		if _, cleanupErr := s.db.ExecContext(ctx, q, userID); cleanupErr != nil {
			s.logger.Warn(ctx, "Non-fatal cleanup failure during user delete", map[string]interface{}{"error": cleanupErr.Error(), "query": q, "user_id": userID})
		}
	}

	result, err := s.db.ExecContext(ctx, `DELETE FROM users WHERE id = $1`, userID)
	if err != nil {
		return contextutils.WrapError(err, "failed to delete user")
	}
	rowsAffected, err := result.RowsAffected()
	if err != nil {
		return contextutils.WrapError(err, "failed to get rows affected")
	}
	if rowsAffected == 0 {
		return contextutils.WrapError(contextutils.ErrRecordNotFound, "user not found")
	}

	// The user ID travels as a structured field; the message itself carries no format verbs.
	s.logger.Info(ctx, "User deleted successfully", map[string]interface{}{"user_id": userID})
	return nil
}
// DeleteAllUsers deletes every row from user_responses, performance_metrics and users (in
// that order) inside a single transaction. For each of those tables it restarts the
// "<table>_id_seq" sequence at 1 when such a sequence exists.
//
// NOTE(review): other tables referencing users (e.g. user_roles, user_api_keys) are not
// cleared here; whether the users delete succeeds depends on their foreign-key rules.
func (s *UserService) DeleteAllUsers(ctx context.Context) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "delete_all_users")
	defer observability.FinishSpan(span, &err)

	var tx *sql.Tx
	tx, err = s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapError(err, "failed to begin transaction for delete all users")
	}
	// Roll back on every exit path; after a successful Commit this returns sql.ErrTxDone.
	defer func() {
		if rollbackErr := tx.Rollback(); rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
			s.logger.Warn(ctx, "Warning: failed to rollback transaction", map[string]interface{}{"error": rollbackErr.Error()})
		}
	}()

	// Fixed, code-controlled table list, ordered children-first to respect foreign keys. The
	// names are interpolated into SQL below, so they must never come from external input.
	tables := []string{
		"user_responses",
		"performance_metrics",
		"users",
	}
	for _, table := range tables {
		if _, err = tx.ExecContext(ctx, fmt.Sprintf("DELETE FROM %s", table)); err != nil {
			return contextutils.WrapErrorf(err, "failed to delete from table %s", table)
		}
		// IF EXISTS keeps tables without an id sequence from raising an error. In PostgreSQL
		// any failed statement aborts the whole transaction, so a genuine failure is returned
		// rather than logged and ignored.
		if _, err = tx.ExecContext(ctx, fmt.Sprintf("ALTER SEQUENCE IF EXISTS %s_id_seq RESTART WITH 1", table)); err != nil {
			return contextutils.WrapErrorf(err, "failed to reset sequence for table %s", table)
		}
	}

	if err = tx.Commit(); err != nil {
		return contextutils.WrapError(err, "failed to commit delete all users transaction")
	}
	return nil
}
// EnsureAdminUserExists makes sure a user named adminUsername exists, can log in with
// adminPassword, and holds the "admin" role.
//
// For an existing user, the password hash is rewritten only when none is stored or the
// stored bcrypt hash does not match adminPassword. After that, on a best-effort basis
// (failures are logged as warnings, not returned), default AI settings are applied, a NULL
// email or timezone is filled in, and the admin role is granted if missing.
//
// When no such user exists, one is created with email "admin@example.com", timezone
// "America/New_York", language "italian" and level "A1", then given the hashed password. AI
// settings and the admin role are then applied best-effort.
func (s *UserService) EnsureAdminUserExists(ctx context.Context, adminUsername, adminPassword string) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "ensure_admin_user_exists", attribute.String("admin.username", adminUsername))
	defer observability.FinishSpan(span, &err)

	if adminUsername == "" {
		return contextutils.ErrorWithContextf("admin username cannot be empty")
	}
	if adminPassword == "" {
		return contextutils.ErrorWithContextf("admin password cannot be empty")
	}

	var existingUser *models.User
	existingUser, err = s.GetUserByUsername(ctx, adminUsername)
	if err != nil {
		return contextutils.WrapError(err, "failed to check if admin user exists")
	}

	if existingUser != nil {
		// Any comparison failure (mismatch or unreadable hash) leads to re-hashing.
		passwordMatches := existingUser.PasswordHash.Valid &&
			bcrypt.CompareHashAndPassword([]byte(existingUser.PasswordHash.String), []byte(adminPassword)) == nil

		if !passwordMatches {
			hashedPassword, hashErr := bcrypt.GenerateFromPassword([]byte(adminPassword), bcrypt.DefaultCost)
			if hashErr != nil {
				return contextutils.WrapError(hashErr, "failed to hash admin password")
			}
			// Update by primary key so exactly the row returned by the lookup above is changed.
			query := `UPDATE users SET password_hash = $1, updated_at = $2 WHERE id = $3`
			if _, execErr := s.db.ExecContext(ctx, query, string(hashedPassword), time.Now(), existingUser.ID); execErr != nil {
				return contextutils.WrapError(execErr, "failed to update admin user password")
			}
		}

		// Best-effort follow-up configuration, shared by the "password ok" and "password
		// updated" cases.
		if aiErr := s.ensureAdminAISettings(ctx, existingUser.ID); aiErr != nil {
			s.logger.Warn(ctx, "Warning: Failed to set AI settings for existing admin user", map[string]interface{}{"error": aiErr.Error()})
		}
		if !existingUser.Email.Valid || !existingUser.Timezone.Valid {
			if profileErr := s.ensureAdminProfile(ctx, existingUser.ID); profileErr != nil {
				s.logger.Warn(ctx, "Warning: Failed to update admin profile", map[string]interface{}{"error": profileErr.Error()})
			}
		}
		isAdmin, roleErr := s.IsAdmin(ctx, existingUser.ID)
		if roleErr != nil {
			s.logger.Warn(ctx, "Warning: Failed to check admin role for existing admin user", map[string]interface{}{"error": roleErr.Error()})
		} else if !isAdmin {
			if assignErr := s.AssignRoleByName(ctx, existingUser.ID, "admin"); assignErr != nil {
				s.logger.Warn(ctx, "Warning: Failed to assign admin role to existing admin user", map[string]interface{}{"error": assignErr.Error()})
			}
		}

		if passwordMatches {
			s.logger.Info(ctx, "Admin user already exists with correct password", map[string]interface{}{"username": adminUsername})
		} else {
			s.logger.Info(ctx, "Updated password for admin user", map[string]interface{}{"username": adminUsername})
		}
		return nil
	}

	// Create new admin user with email and timezone.
	user, err := s.CreateUserWithEmailAndTimezone(ctx, adminUsername, "admin@example.com", "America/New_York", "italian", "A1")
	if err != nil {
		return contextutils.WrapError(err, "failed to create admin user")
	}

	hashedPassword, err := bcrypt.GenerateFromPassword([]byte(adminPassword), bcrypt.DefaultCost)
	if err != nil {
		return contextutils.WrapError(err, "failed to hash new admin password")
	}
	query := `UPDATE users SET password_hash = $1, updated_at = $2 WHERE id = $3`
	if _, err = s.db.ExecContext(ctx, query, string(hashedPassword), time.Now(), user.ID); err != nil {
		return contextutils.WrapError(err, "failed to set password for new admin user")
	}

	if aiErr := s.ensureAdminAISettings(ctx, user.ID); aiErr != nil {
		s.logger.Warn(ctx, "Warning: Failed to set AI settings for new admin user", map[string]interface{}{"error": aiErr.Error()})
	}
	if roleErr := s.AssignRoleByName(ctx, user.ID, "admin"); roleErr != nil {
		s.logger.Warn(ctx, "Warning: Failed to assign admin role to new admin user", map[string]interface{}{"error": roleErr.Error()})
	}

	s.logger.Info(ctx, "Created admin user", map[string]interface{}{"username": adminUsername})
	return nil
}
// ensureAdminAISettings gives the user with ID userID default AI settings: provider
// "ollama", model "llama4:latest" and API key "not_needed". If the user already has a
// non-empty AI provider stored, nothing is changed.
//
// The settings are written to the users row. The key is additionally saved through
// SetUserAPIKey; a failure there is only logged. A missing user yields an error wrapping
// contextutils.ErrRecordNotFound.
func (s *UserService) ensureAdminAISettings(ctx context.Context, userID int) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "ensure_admin_ai_settings", attribute.Int("user.id", userID))
	defer observability.FinishSpan(span, &err)

	var user *models.User
	user, err = s.GetUserByID(ctx, userID)
	if err != nil {
		return contextutils.WrapError(err, "failed to load admin user for AI settings")
	}
	if user == nil {
		return contextutils.WrapError(contextutils.ErrRecordNotFound, "admin user not found")
	}

	// If user already has AI provider configured, don't override their settings.
	if user.AIProvider.Valid && user.AIProvider.String != "" {
		s.logger.Info(ctx, "User ID already has AI settings configured, preserving existing settings", map[string]interface{}{"user_id": userID, "provider": user.AIProvider.String})
		return nil
	}

	const (
		defaultAIProvider = "ollama"
		defaultAIModel    = "llama4:latest"
		// Placeholder key. NOTE(review): presumably the default provider needs no real key — confirm.
		defaultAIAPIKey = "not_needed"
	)

	// Only update AI settings; other user settings are preserved.
	const query = `UPDATE users SET ai_provider = $1, ai_model = $2, ai_api_key = $3, updated_at = $4 WHERE id = $5`
	if _, err = s.db.ExecContext(ctx, query, defaultAIProvider, defaultAIModel, defaultAIAPIKey, time.Now(), userID); err != nil {
		return contextutils.WrapError(err, "failed to update user AI settings")
	}

	// Best-effort: the users row above already holds the key, so a failure is only logged.
	if keyErr := s.SetUserAPIKey(ctx, userID, defaultAIProvider, defaultAIAPIKey); keyErr != nil {
		s.logger.Warn(ctx, "Warning: Failed to save API key for user", map[string]interface{}{"user_id": userID, "error": keyErr.Error()})
	}

	s.logger.Info(ctx, "Set default AI settings for user", map[string]interface{}{"user_id": userID, "provider": defaultAIProvider, "model": defaultAIModel})
	return nil
}
// ensureAdminProfile fills in defaults for the user with ID userID: a NULL email becomes
// "admin@example.com" and a NULL timezone becomes "America/New_York". Columns that already
// hold a value are left untouched, and nothing is written when neither column is NULL.
func (s *UserService) ensureAdminProfile(ctx context.Context, userID int) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "ensure_admin_profile", attribute.Int("user.id", userID))
	defer observability.FinishSpan(span, &err)

	// COALESCE keeps whichever column is already set. Without it, a real email would be
	// replaced when only the timezone is missing (and vice versa).
	const query = `UPDATE users SET email = COALESCE(email, $1), timezone = COALESCE(timezone, $2), updated_at = $3 WHERE id = $4 AND (email IS NULL OR timezone IS NULL)`
	var result sql.Result
	result, err = s.db.ExecContext(ctx, query, "admin@example.com", "America/New_York", time.Now(), userID)
	if err != nil {
		return contextutils.WrapError(err, "failed to update admin profile")
	}

	// The affected-row count only decides whether to log. If the driver cannot report it,
	// log anyway.
	if n, countErr := result.RowsAffected(); countErr == nil && n == 0 {
		return nil
	}
	s.logger.Info(ctx, "Updated admin user profile with default email and timezone", map[string]interface{}{"user_id": userID})
	return nil
}
// ResetDatabase empties user_responses, performance_metrics, questions and users (in that
// order) inside a single transaction. For each of those tables it restarts the
// "<table>_id_seq" sequence at 1 when such a sequence exists.
//
// NOTE(review): other tables that reference these (e.g. user_questions, user_roles) are not
// cleared here; success depends on their foreign-key ON DELETE rules.
func (s *UserService) ResetDatabase(ctx context.Context) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "reset_database")
	defer observability.FinishSpan(span, &err)

	var tx *sql.Tx
	tx, err = s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapError(err, "failed to begin transaction for database reset")
	}
	// Roll back on every exit path; after a successful Commit this returns sql.ErrTxDone.
	defer func() {
		if rollbackErr := tx.Rollback(); rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
			s.logger.Warn(ctx, "Warning: failed to rollback transaction", map[string]interface{}{"error": rollbackErr.Error()})
		}
	}()

	// Fixed, code-controlled table list, ordered children-first to respect foreign keys. The
	// names are interpolated into SQL below, so they must never come from external input.
	tables := []string{
		"user_responses",
		"performance_metrics",
		"questions",
		"users",
	}
	for _, table := range tables {
		if _, err = tx.ExecContext(ctx, fmt.Sprintf("DELETE FROM %s", table)); err != nil {
			return contextutils.WrapErrorf(err, "failed to delete from table %s during reset", table)
		}
		s.logger.Info(ctx, "Cleared table", map[string]interface{}{"table": table})

		// IF EXISTS keeps tables without an id sequence from raising an error. In PostgreSQL
		// any failed statement aborts the whole transaction, so a genuine failure is returned
		// rather than logged and ignored.
		if _, err = tx.ExecContext(ctx, fmt.Sprintf("ALTER SEQUENCE IF EXISTS %s_id_seq RESTART WITH 1", table)); err != nil {
			return contextutils.WrapErrorf(err, "failed to reset sequence for table %s during reset", table)
		}
	}

	if err = tx.Commit(); err != nil {
		return contextutils.WrapError(err, "failed to commit database reset transaction")
	}
	s.logger.Info(ctx, "Database reset completed successfully")
	return nil
}
// ClearUserData empties user_responses, performance_metrics and questions (in that order)
// inside a single transaction, leaving the users table untouched. For each cleared table it
// restarts the "<table>_id_seq" sequence at 1 when such a sequence exists.
//
// NOTE(review): tables that reference questions (e.g. user_questions) are not cleared
// explicitly; success depends on their foreign-key ON DELETE rules.
func (s *UserService) ClearUserData(ctx context.Context) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "clear_user_data")
	defer observability.FinishSpan(span, &err)

	var tx *sql.Tx
	tx, err = s.db.BeginTx(ctx, nil)
	if err != nil {
		return contextutils.WrapError(err, "failed to begin transaction for clear user data")
	}
	// Roll back on every exit path; after a successful Commit this returns sql.ErrTxDone.
	defer func() {
		if rollbackErr := tx.Rollback(); rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
			s.logger.Warn(ctx, "Warning: failed to rollback transaction", map[string]interface{}{"error": rollbackErr.Error()})
		}
	}()

	// Fixed, code-controlled table list, ordered children-first to respect foreign keys. The
	// names are interpolated into SQL below, so they must never come from external input.
	tables := []string{
		"user_responses",
		"performance_metrics",
		"questions",
	}
	for _, table := range tables {
		if _, err = tx.ExecContext(ctx, fmt.Sprintf("DELETE FROM %s", table)); err != nil {
			return contextutils.WrapErrorf(err, "failed to delete from table %s during clear user data", table)
		}
		s.logger.Info(ctx, "Cleared table", map[string]interface{}{"table": table})

		// IF EXISTS keeps tables without an id sequence from raising an error. In PostgreSQL
		// any failed statement aborts the whole transaction, so a genuine failure is returned
		// rather than logged and ignored.
		if _, err = tx.ExecContext(ctx, fmt.Sprintf("ALTER SEQUENCE IF EXISTS %s_id_seq RESTART WITH 1", table)); err != nil {
			return contextutils.WrapErrorf(err, "failed to reset sequence for table %s during clear user data", table)
		}
	}

	if err = tx.Commit(); err != nil {
		return contextutils.WrapError(err, "failed to commit clear user data transaction")
	}
	s.logger.Info(ctx, "User data cleared successfully (users preserved)")
	return nil
}
// ClearUserDataForUser removes activity data tied to userID inside one transaction and
// keeps the users row. It deletes, in order:
//   - user_responses for questions assigned to the user via user_questions;
//   - the user's performance_metrics rows;
//   - the user's user_questions rows;
//   - every question no longer referenced by any user_questions row.
//
// NOTE(review): the user_responses delete matches on question_id only, so responses by
// OTHER users to the same questions are removed as well. The orphan-question delete is
// global, not limited to this user's former questions. Confirm both are intended.
func (s *UserService) ClearUserDataForUser(ctx context.Context, userID int) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "clear_user_data_for_user", attribute.Int("user.id", userID))
	defer observability.FinishSpan(span, &err)

	var tx *sql.Tx
	tx, err = s.db.BeginTx(ctx, nil)
	if err != nil {
		s.logger.Warn(ctx, "Failed to begin transaction", map[string]interface{}{"error": err.Error()})
		return contextutils.WrapError(err, "failed to begin transaction for clear user data for specific user")
	}
	// Roll back on every exit path; after a successful Commit this returns sql.ErrTxDone.
	defer func() {
		if rollbackErr := tx.Rollback(); rollbackErr != nil && !errors.Is(rollbackErr, sql.ErrTxDone) {
			s.logger.Warn(ctx, "Warning: failed to rollback transaction", map[string]interface{}{"error": rollbackErr.Error()})
		}
	}()

	// deleteRows runs one DELETE in the transaction and returns the affected-row count. The
	// count is only logged, so a driver that cannot report it yields 0 instead of an error.
	deleteRows := func(what, query string, args ...interface{}) (int64, error) {
		result, execErr := tx.ExecContext(ctx, query, args...)
		if execErr != nil {
			s.logger.Warn(ctx, "Failed to delete "+what, map[string]interface{}{"error": execErr.Error()})
			return 0, execErr
		}
		count, _ := result.RowsAffected() // best-effort count for logging only
		return count, nil
	}

	count, err := deleteRows("user_responses",
		`DELETE FROM user_responses WHERE question_id IN (SELECT question_id FROM user_questions WHERE user_id = $1)`, userID)
	if err != nil {
		return contextutils.WrapError(err, "failed to delete user responses for specific user")
	}
	s.logger.Info(ctx, "Deleted user_responses for user", map[string]interface{}{"count": count, "user_id": userID})

	// performance_metrics has user_id, not question_id.
	count, err = deleteRows("performance_metrics", `DELETE FROM performance_metrics WHERE user_id = $1`, userID)
	if err != nil {
		return contextutils.WrapError(err, "failed to delete performance metrics for specific user")
	}
	s.logger.Info(ctx, "Deleted performance_metrics for user", map[string]interface{}{"count": count, "user_id": userID})

	count, err = deleteRows("user_questions", `DELETE FROM user_questions WHERE user_id = $1`, userID)
	if err != nil {
		return contextutils.WrapError(err, "failed to delete user questions for specific user")
	}
	s.logger.Info(ctx, "Deleted user_questions for user", map[string]interface{}{"count": count, "user_id": userID})

	// Delete questions no longer assigned to any user.
	count, err = deleteRows("orphaned questions", `DELETE FROM questions WHERE id NOT IN (SELECT question_id FROM user_questions)`)
	if err != nil {
		return contextutils.WrapError(err, "failed to delete orphaned questions")
	}
	s.logger.Info(ctx, "Deleted orphaned questions", map[string]interface{}{"count": count})

	if err = tx.Commit(); err != nil {
		s.logger.Warn(ctx, "Failed to commit transaction", map[string]interface{}{"error": err.Error()})
		return contextutils.WrapError(err, "failed to commit clear user data for specific user transaction")
	}
	s.logger.Info(ctx, "User data cleared successfully for user (users preserved)", map[string]interface{}{"user_id": userID})
	return nil
}
// applyDefaultSettings fills blank preference fields on user (in memory only) from s.cfg:
//   - no AI provider: the first configured provider and, if it lists any, its first model;
//   - no current level: the first level of the user's preferred language, or of the first
//     configured language when no preference is set;
//   - no preferred language: "english".
//
// It does nothing when user or s.cfg is nil.
//
// NOTE(review): only the .String fields are assigned and the .Valid flags are left as they
// were. Callers elsewhere check .Valid (e.g. AIProvider.Valid), so confirm this is
// intentional before relying on these defaults being persisted.
// NOTE(review): the level fallback uses the first configured language while the language
// fallback is "english"; these can disagree.
func (s *UserService) applyDefaultSettings(ctx context.Context, user *models.User) {
	if user == nil || s.cfg == nil {
		return
	}
	_, span := observability.TraceUserFunction(ctx, "apply_default_settings", attribute.Int("user.id", user.ID))
	defer span.End()

	if user.AIProvider.String == "" && len(s.cfg.Providers) > 0 {
		fallbackProvider := s.cfg.Providers[0]
		user.AIProvider.String = fallbackProvider.Code
		if len(fallbackProvider.Models) > 0 {
			user.AIModel.String = fallbackProvider.Models[0].Code
		}
	}

	if user.CurrentLevel.String == "" {
		lang := user.PreferredLanguage.String
		if lang == "" {
			if configured := s.cfg.GetLanguages(); len(configured) > 0 {
				lang = configured[0]
			}
		}
		if lang != "" {
			if levels := s.cfg.GetLevelsForLanguage(lang); len(levels) > 0 {
				user.CurrentLevel.String = levels[0]
			}
		}
	}

	if user.PreferredLanguage.String == "" {
		user.PreferredLanguage.String = "english"
	}
}
// GetUserRoles returns the roles linked to userID through user_roles, ordered by role name.
// The slice is nil when the user has no roles. A non-nil error is recorded on the trace span
// (with a stack trace) before the span ends.
func (s *UserService) GetUserRoles(ctx context.Context, userID int) (result0 []models.Role, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "get_user_roles", attribute.Int("user.id", userID))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const userRolesSQL = `
SELECT r.id, r.name, r.description, r.created_at, r.updated_at
FROM roles r
JOIN user_roles ur ON r.id = ur.role_id
WHERE ur.user_id = $1
ORDER BY r.name
`
	rows, err := s.db.QueryContext(ctx, userRolesSQL, userID)
	if err != nil {
		return nil, contextutils.WrapError(err, "failed to get user roles")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			s.logger.Warn(ctx, "Warning: failed to close rows", map[string]interface{}{"error": closeErr.Error()})
		}
	}()

	var assigned []models.Role
	for rows.Next() {
		var r models.Role
		if scanErr := rows.Scan(&r.ID, &r.Name, &r.Description, &r.CreatedAt, &r.UpdatedAt); scanErr != nil {
			return nil, contextutils.WrapError(scanErr, "failed to scan user role")
		}
		assigned = append(assigned, r)
	}
	if err = rows.Err(); err != nil {
		return nil, contextutils.WrapError(err, "error iterating user roles")
	}
	return assigned, nil
}
// AssignRole grants the role with ID roleID to the user with ID userID. Both must exist.
// Granting a role the user already holds is a no-op (ON CONFLICT DO NOTHING). A non-nil
// error is recorded on the trace span before it ends.
func (s *UserService) AssignRole(ctx context.Context, userID, roleID int) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "assign_role", attribute.Int("user.id", userID), attribute.Int("role.id", roleID))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	target, lookupErr := s.GetUserByID(ctx, userID)
	switch {
	case lookupErr != nil:
		return contextutils.WrapError(lookupErr, "failed to get user for role assignment")
	case target == nil:
		return contextutils.ErrorWithContextf("user with ID %d not found", userID)
	}

	// The role name is fetched both to prove the role exists and for the log entry below.
	var roleName string
	switch scanErr := s.db.QueryRowContext(ctx, "SELECT name FROM roles WHERE id = $1", roleID).Scan(&roleName); {
	case errors.Is(scanErr, sql.ErrNoRows):
		return contextutils.ErrorWithContextf("role with ID %d not found", roleID)
	case scanErr != nil:
		return contextutils.WrapError(scanErr, "failed to check role existence")
	}

	const insertSQL = `INSERT INTO user_roles (user_id, role_id, created_at) VALUES ($1, $2, $3) ON CONFLICT (user_id, role_id) DO NOTHING`
	if _, execErr := s.db.ExecContext(ctx, insertSQL, userID, roleID, time.Now()); execErr != nil {
		return contextutils.WrapError(execErr, "failed to assign role to user")
	}

	s.logger.Info(ctx, "Role assigned successfully", map[string]interface{}{
		"role_name": roleName,
		"role_id":   roleID,
		"user_id":   userID,
	})
	return nil
}
// AssignRoleByName grants the role named roleName to the user with ID userID. Both the user
// and the role must exist. Granting a role the user already holds is a no-op (ON CONFLICT
// DO NOTHING). A non-nil error is recorded on the trace span before it ends.
func (s *UserService) AssignRoleByName(ctx context.Context, userID int, roleName string) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "assign_role_by_name", attribute.Int("user.id", userID), attribute.String("role.name", roleName))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	target, lookupErr := s.GetUserByID(ctx, userID)
	switch {
	case lookupErr != nil:
		return contextutils.WrapError(lookupErr, "failed to get user for role assignment")
	case target == nil:
		return contextutils.ErrorWithContextf("user with ID %d not found", userID)
	}

	// Resolve the role name to its ID.
	var roleID int
	switch scanErr := s.db.QueryRowContext(ctx, "SELECT id FROM roles WHERE name = $1", roleName).Scan(&roleID); {
	case errors.Is(scanErr, sql.ErrNoRows):
		return contextutils.ErrorWithContextf("role with name '%s' not found", roleName)
	case scanErr != nil:
		return contextutils.WrapError(scanErr, "failed to get role ID by name")
	}

	const insertSQL = `INSERT INTO user_roles (user_id, role_id, created_at) VALUES ($1, $2, $3) ON CONFLICT (user_id, role_id) DO NOTHING`
	if _, execErr := s.db.ExecContext(ctx, insertSQL, userID, roleID, time.Now()); execErr != nil {
		return contextutils.WrapError(execErr, "failed to assign role to user")
	}

	s.logger.Info(ctx, "Role assigned successfully", map[string]interface{}{
		"role_name": roleName,
		"role_id":   roleID,
		"user_id":   userID,
	})
	return nil
}
// RemoveRole revokes the role with ID roleID from the user with ID userID. The user and the
// role must both exist, and the user must currently hold the role; otherwise an error is
// returned. A non-nil error is recorded on the trace span before it ends.
func (s *UserService) RemoveRole(ctx context.Context, userID, roleID int) (err error) {
	ctx, span := observability.TraceUserFunction(ctx, "remove_role", attribute.Int("user.id", userID), attribute.Int("role.id", roleID))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	target, lookupErr := s.GetUserByID(ctx, userID)
	switch {
	case lookupErr != nil:
		return contextutils.WrapError(lookupErr, "failed to get user for role removal")
	case target == nil:
		return contextutils.ErrorWithContextf("user with ID %d not found", userID)
	}

	var roleName string
	switch scanErr := s.db.QueryRowContext(ctx, "SELECT name FROM roles WHERE id = $1", roleID).Scan(&roleName); {
	case errors.Is(scanErr, sql.ErrNoRows):
		return contextutils.ErrorWithContextf("role with ID %d not found", roleID)
	case scanErr != nil:
		return contextutils.WrapError(scanErr, "failed to check role existence")
	}

	const deleteSQL = `DELETE FROM user_roles WHERE user_id = $1 AND role_id = $2`
	res, execErr := s.db.ExecContext(ctx, deleteSQL, userID, roleID)
	if execErr != nil {
		return contextutils.WrapError(execErr, "failed to remove role from user")
	}
	removed, countErr := res.RowsAffected()
	if countErr != nil {
		return contextutils.WrapError(countErr, "failed to get rows affected")
	}
	if removed == 0 {
		// The user and role exist but were not linked.
		return contextutils.ErrorWithContextf("user %d does not have role %d", userID, roleID)
	}

	s.logger.Info(ctx, "Role removed successfully", map[string]interface{}{
		"role_name": roleName,
		"role_id":   roleID,
		"user_id":   userID,
	})
	return nil
}
// HasRole reports whether the user with ID userID is linked through user_roles to a role
// whose name equals roleName. A non-nil error is recorded on the trace span before it ends.
func (s *UserService) HasRole(ctx context.Context, userID int, roleName string) (result0 bool, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "has_role", attribute.Int("user.id", userID), attribute.String("role.name", roleName))
	defer func() {
		if err != nil {
			span.RecordError(err, trace.WithStackTrace(true))
			span.SetStatus(codes.Error, err.Error())
		}
		span.End()
	}()

	const hasRoleSQL = `
SELECT COUNT(*) > 0
FROM user_roles ur
JOIN roles r ON ur.role_id = r.id
WHERE ur.user_id = $1 AND r.name = $2
`
	var found bool
	if err = s.db.QueryRowContext(ctx, hasRoleSQL, userID, roleName).Scan(&found); err != nil {
		return false, contextutils.WrapError(err, "failed to check if user has role")
	}
	return found, nil
}
// IsAdmin reports whether the user with ID userID holds the role named "admin". It runs
// HasRole inside its own "is_admin" trace span.
func (s *UserService) IsAdmin(ctx context.Context, userID int) (result0 bool, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "is_admin", attribute.Int("user.id", userID))
	defer observability.FinishSpan(span, &err)

	result0, err = s.HasRole(ctx, userID, "admin")
	return result0, err
}
// GetAllRoles lists every row of the roles table, ordered by name.
// It returns a nil slice when the table is empty. A failure to close the
// result set is logged as a warning instead of being returned.
func (s *UserService) GetAllRoles(ctx context.Context) (result0 []models.Role, err error) {
	ctx, span := observability.TraceUserFunction(ctx, "get_all_roles")
	defer observability.FinishSpan(span, &err)
	const query = `
SELECT id, name, description, created_at, updated_at
FROM roles
ORDER BY name
`
	rows, queryErr := s.db.QueryContext(ctx, query)
	if queryErr != nil {
		return nil, contextutils.WrapError(queryErr, "failed to get all roles")
	}
	defer func() {
		closeErr := rows.Close()
		if closeErr == nil {
			return
		}
		s.logger.Warn(ctx, "Warning: failed to close rows", map[string]interface{}{"error": closeErr.Error()})
	}()
	var roles []models.Role
	for rows.Next() {
		var r models.Role
		if scanErr := rows.Scan(&r.ID, &r.Name, &r.Description, &r.CreatedAt, &r.UpdatedAt); scanErr != nil {
			return nil, contextutils.WrapError(scanErr, "failed to scan role")
		}
		roles = append(roles, r)
	}
	if iterErr := rows.Err(); iterErr != nil {
		return nil, contextutils.WrapError(iterErr, "error iterating roles")
	}
	return roles, nil
}
// GetDB returns the *sql.DB handle stored on the service, so callers can run
// queries the service does not wrap. The handle is shared, not copied.
func (s *UserService) GetDB() *sql.DB {
return s.db
}
// isDuplicateKeyError reports whether err, or any error in its wrap chain, is a
// PostgreSQL unique-constraint violation (SQLSTATE 23505) raised by lib/pq.
//
// errors.As is used rather than a direct type assertion so the check still
// succeeds after the driver error has been wrapped with extra context, for
// example by fmt.Errorf("...: %w", err). A nil err returns false.
func isDuplicateKeyError(err error) bool {
	var pqErr *pq.Error
	if !errors.As(err, &pqErr) || pqErr == nil {
		return false
	}
	// 23505 is unique_violation in PostgreSQL's SQLSTATE table.
	return pqErr.Code == "23505"
}
package services
import (
"context"
"math/rand"
"go.opentelemetry.io/otel/attribute"
"quizapp/internal/config"
"quizapp/internal/observability"
)
// VarietyService handles the selection of variety elements for question generation.
// The option lists it draws from come from cfg.Variety. When that section is
// nil, SelectVarietyElements returns an empty VarietyElements.
type VarietyService struct {
cfg *config.Config // source of the variety option lists (cfg.Variety)
logger *observability.Logger // logger supplied at construction; not read by the methods shown here
}
// VarietyElements holds the randomly selected variety elements for a question generation request.
// SelectVarietyElements fills in only the two or three dimensions it picks
// for a given request. Every other field is left as the empty string.
type VarietyElements struct {
TopicCategory string // one of the configured topic categories
GrammarFocus string // level-specific grammar focus, or the generic list when no level entry exists
VocabularyDomain string // one of the configured vocabulary domains
Scenario string // one of the configured scenarios
StyleModifier string // one of the configured style modifiers
DifficultyModifier string // one of the configured difficulty modifiers
TimeContext string // one of the configured time contexts
}
// NewVarietyServiceWithLogger builds a VarietyService that reads its option
// lists from cfg and stores logger on the service. Neither argument is
// validated or copied.
func NewVarietyServiceWithLogger(cfg *config.Config, logger *observability.Logger) *VarietyService {
	vs := new(VarietyService)
	vs.cfg = cfg
	vs.logger = logger
	return vs
}
// SelectVarietyElements randomly selects variety elements for question generation.
//
// Only dimensions with at least one configured option take part. Two of them
// are chosen at random. When more than two are configured there is a 30%
// chance of choosing three instead. Fields of the returned VarietyElements
// that were not chosen stay empty.
//
// Topic selection works through these sources in order, stopping at the first
// that yields a match:
//   - userWeakAreas
//   - highPriorityTopics
//   - gapAnalysis
//   - a uniform random pick
//
// Grammar focus, vocabulary domain and scenario are biased by gapAnalysis.
// Grammar focus uses the level-specific list when it has entries, otherwise
// the generic list. Style modifier, difficulty modifier and time context are
// always picked uniformly. Any biased dimension falls back to a uniform pick.
//
// gapAnalysis is looked up with keys of the form "<dimension>_<option>"
// (e.g. "scenario_<option>"). An option's weight is the square of its
// positive count.
//
// When there is no variety configuration (nil cfg or nil cfg.Variety), an
// empty VarietyElements is returned. The result is never nil.
func (vs *VarietyService) SelectVarietyElements(ctx context.Context, level string, highPriorityTopics, userWeakAreas []string, gapAnalysis map[string]int) *VarietyElements {
	_, span := observability.TraceVarietyFunction(ctx, "select_variety_elements",
		attribute.String("variety.level", level),
		attribute.Int("variety.high_priority_topics_count", len(highPriorityTopics)),
		attribute.Int("variety.user_weak_areas_count", len(userWeakAreas)),
		attribute.Int("variety.gap_analysis_count", len(gapAnalysis)),
	)
	defer span.End()

	if vs.cfg == nil || vs.cfg.Variety == nil {
		span.SetAttributes(attribute.String("variety.result", "no_config"))
		return &VarietyElements{} // Return empty if no variety config
	}
	variety := vs.cfg.Variety
	elements := &VarietyElements{}

	// getWeightedSelection picks one of availableOptions with probability
	// proportional to the square of its positive gapAnalysis count. Squaring
	// intensifies the bias toward the largest gaps. It returns "" when no
	// option has a positive count.
	//
	// It sums weights rather than building a list with count² copies of each
	// option, so memory stays O(len(availableOptions)) however large the
	// counts are. For a given random draw the chosen option is the same as
	// indexing into that expanded list.
	getWeightedSelection := func(gapType string, availableOptions []string) string {
		if len(gapAnalysis) == 0 || len(availableOptions) == 0 {
			return ""
		}
		weights := make([]int, len(availableOptions))
		total := 0
		for i, option := range availableOptions {
			if count := gapAnalysis[gapType+"_"+option]; count > 0 {
				weights[i] = count * count
				total += weights[i]
			}
		}
		// rand.Intn requires a positive bound.
		if total <= 0 {
			return ""
		}
		pick := rand.Intn(total)
		for i, w := range weights {
			if pick < w {
				return availableOptions[i]
			}
			pick -= w
		}
		return ""
	}

	// gapOrRandom prefers a gap-weighted option and otherwise picks uniformly
	// from options. options must be non-empty.
	gapOrRandom := func(gapType string, options []string) string {
		if selected := getWeightedSelection(gapType, options); selected != "" {
			return selected
		}
		return options[rand.Intn(len(options))]
	}

	// pickTopicFrom returns a uniformly random configured topic category that
	// also appears in preferred, or ok=false when nothing matches. Each
	// occurrence in preferred counts, so duplicates raise a topic's odds.
	pickTopicFrom := func(preferred []string) (topic string, ok bool) {
		var matching []string
		for _, candidate := range variety.TopicCategories {
			for _, p := range preferred {
				if candidate == p {
					matching = append(matching, candidate)
				}
			}
		}
		if len(matching) == 0 {
			return "", false
		}
		return matching[rand.Intn(len(matching))], true
	}

	// varietySelector pairs a dimension name with a function that picks a
	// value, stores it in elements and returns it. The name is also used as a
	// span-attribute suffix.
	type varietySelector struct {
		name     string
		selector func() string
	}
	var selectors []varietySelector

	// Topic category: userWeakAreas, then highPriorityTopics, then gapAnalysis,
	// then uniform.
	if len(variety.TopicCategories) > 0 {
		selectors = append(selectors, varietySelector{
			name: "topic_category",
			selector: func() string {
				topic, ok := pickTopicFrom(userWeakAreas)
				if !ok {
					topic, ok = pickTopicFrom(highPriorityTopics)
				}
				if !ok {
					topic = gapOrRandom("topic_category", variety.TopicCategories)
				}
				elements.TopicCategory = topic
				return topic
			},
		})
	}

	// Grammar focus: the level-specific list when it has entries, otherwise
	// the generic list.
	grammarOptions := variety.GrammarFocusByLevel[level]
	if len(grammarOptions) == 0 {
		grammarOptions = variety.GrammarFocus
	}
	if len(grammarOptions) > 0 {
		selectors = append(selectors, varietySelector{
			name: "grammar_focus",
			selector: func() string {
				elements.GrammarFocus = gapOrRandom("grammar_focus", grammarOptions)
				return elements.GrammarFocus
			},
		})
	}

	// Vocabulary domain (gap-biased).
	if len(variety.VocabularyDomains) > 0 {
		selectors = append(selectors, varietySelector{
			name: "vocabulary_domain",
			selector: func() string {
				elements.VocabularyDomain = gapOrRandom("vocabulary_domain", variety.VocabularyDomains)
				return elements.VocabularyDomain
			},
		})
	}

	// Scenario (gap-biased).
	if len(variety.Scenarios) > 0 {
		selectors = append(selectors, varietySelector{
			name: "scenario",
			selector: func() string {
				elements.Scenario = gapOrRandom("scenario", variety.Scenarios)
				return elements.Scenario
			},
		})
	}

	// Style modifier (uniform).
	if len(variety.StyleModifiers) > 0 {
		selectors = append(selectors, varietySelector{
			name: "style_modifier",
			selector: func() string {
				elements.StyleModifier = variety.StyleModifiers[rand.Intn(len(variety.StyleModifiers))]
				return elements.StyleModifier
			},
		})
	}

	// Difficulty modifier (uniform).
	if len(variety.DifficultyModifiers) > 0 {
		selectors = append(selectors, varietySelector{
			name: "difficulty_modifier",
			selector: func() string {
				elements.DifficultyModifier = variety.DifficultyModifiers[rand.Intn(len(variety.DifficultyModifiers))]
				return elements.DifficultyModifier
			},
		})
	}

	// Time context (uniform).
	if len(variety.TimeContexts) > 0 {
		selectors = append(selectors, varietySelector{
			name: "time_context",
			selector: func() string {
				elements.TimeContext = variety.TimeContexts[rand.Intn(len(variety.TimeContexts))]
				return elements.TimeContext
			},
		})
	}

	// Randomly select 2-3 variety elements (instead of all 7). When more than
	// two dimensions are available there is a 70% chance of 2 and a 30%
	// chance of 3.
	numToSelect := 2
	if len(selectors) > 2 && rand.Float64() < 0.3 {
		numToSelect = 3
	}
	// With fewer configured dimensions than requested, only that many can be
	// applied. Clamp so the reported count matches what was actually set.
	numToSelect = min(numToSelect, len(selectors))

	// Shuffle, then apply the first numToSelect selectors.
	rand.Shuffle(len(selectors), func(i, j int) {
		selectors[i], selectors[j] = selectors[j], selectors[i]
	})
	for _, sel := range selectors[:numToSelect] {
		span.SetAttributes(attribute.String("variety."+sel.name, sel.selector()))
	}

	span.SetAttributes(
		attribute.String("variety.topic_category", elements.TopicCategory),
		attribute.String("variety.grammar_focus", elements.GrammarFocus),
		attribute.String("variety.vocabulary_domain", elements.VocabularyDomain),
		attribute.String("variety.scenario", elements.Scenario),
		attribute.String("variety.style_modifier", elements.StyleModifier),
		attribute.String("variety.difficulty_modifier", elements.DifficultyModifier),
		attribute.String("variety.time_context", elements.TimeContext),
		attribute.Int("variety.elements_selected", numToSelect),
	)
	span.SetAttributes(attribute.String("variety.result", "success"))
	return elements
}
// SelectMultipleVarietyElements selects count independent sets of variety
// elements for batch generation. Each set is drawn by SelectVarietyElements
// with no topic or gap-analysis bias, using this function's traced context.
//
// A count of zero or less yields an empty, non-nil slice. Previously a
// negative count panicked inside make.
func (vs *VarietyService) SelectMultipleVarietyElements(ctx context.Context, level string, count int) []*VarietyElements {
	ctx, span := observability.TraceVarietyFunction(ctx, "select_multiple_variety_elements",
		attribute.String("variety.level", level),
		attribute.Int("variety.count", count),
	)
	defer span.End()
	// make panics on a negative length, so clamp before allocating.
	elements := make([]*VarietyElements, max(count, 0))
	for i := range elements {
		elements[i] = vs.SelectVarietyElements(ctx, level, nil, nil, nil)
	}
	span.SetAttributes(attribute.String("variety.result", "success"), attribute.Int("variety.elements_count", len(elements)))
	return elements
}
package services
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
"quizapp/internal/models"
"quizapp/internal/observability"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/otel/attribute"
)
// ErrSettingNotFound is returned when a setting is not found in the database.
// GetSetting returns an error that wraps it when worker_settings has no row
// for the requested key. Callers test for it with errors.Is rather than ==,
// as IsGlobalPaused does.
var ErrSettingNotFound = errors.New("setting not found")
// WorkerServiceInterface defines the interface for worker management operations.
// It covers four areas:
//   - key/value worker settings, including the global and per-user pause flags
//   - per-instance worker status and heartbeats
//   - pause/resume control and health reporting
//   - notification reporting
type WorkerServiceInterface interface {
// Settings management
// GetSetting returns an error wrapping ErrSettingNotFound when the key has no row.
GetSetting(ctx context.Context, key string) (string, error)
SetSetting(ctx context.Context, key, value string) error
IsGlobalPaused(ctx context.Context) (bool, error)
SetGlobalPause(ctx context.Context, paused bool) error
// IsUserPaused treats a missing per-user setting as "not paused".
IsUserPaused(ctx context.Context, userID int) (bool, error)
SetUserPause(ctx context.Context, userID int, paused bool) error
// Status management
UpdateWorkerStatus(ctx context.Context, instance string, status *models.WorkerStatus) error
GetWorkerStatus(ctx context.Context, instance string) (*models.WorkerStatus, error)
GetAllWorkerStatuses(ctx context.Context) ([]models.WorkerStatus, error)
UpdateHeartbeat(ctx context.Context, instance string) error
// IsWorkerHealthy is true when the instance's last heartbeat is under 5 minutes old.
IsWorkerHealthy(ctx context.Context, instance string) (bool, error)
// Control operations
// PauseWorker/ResumeWorker only update an existing worker_status row; an
// unknown instance is not reported as an error.
PauseWorker(ctx context.Context, instance string) error
ResumeWorker(ctx context.Context, instance string) error
GetWorkerHealth(ctx context.Context) (map[string]interface{}, error)
// Question-selection analytics for a user/language/level/question type.
// GetHighPriorityTopics returns at most 5 topic categories with average priority >= 7.0.
GetHighPriorityTopics(ctx context.Context, userID int, language, level, questionType string) ([]string, error)
GetGapAnalysis(ctx context.Context, userID int, language, level, questionType string) (map[string]int, error)
GetPriorityDistribution(ctx context.Context, userID int, language, level, questionType string) (map[string]int, error)
// Notification management
// NOTE(review): the meaning of the two map results on the paginated
// notification queries is not visible here — presumably pagination and
// filter metadata; confirm against the implementation.
GetNotificationStats(ctx context.Context) (map[string]interface{}, error)
GetNotificationErrors(ctx context.Context, page, pageSize int, errorType, notificationType, resolved string) ([]map[string]interface{}, map[string]interface{}, map[string]interface{}, error)
GetUpcomingNotifications(ctx context.Context, page, pageSize int, notificationType, status, scheduledAfter, scheduledBefore string) ([]map[string]interface{}, map[string]interface{}, map[string]interface{}, error)
GetSentNotifications(ctx context.Context, page, pageSize int, notificationType, status, sentAfter, sentBefore string) ([]map[string]interface{}, map[string]interface{}, map[string]interface{}, error)
// Test methods for creating test data
CreateTestSentNotification(ctx context.Context, userID int, notificationType, subject, templateName, status, errorMessage string) error
}
// WorkerService implements worker management operations.
// Its queries target the worker_settings and worker_status tables and use
// PostgreSQL syntax ($n placeholders, ON CONFLICT upserts).
type WorkerService struct {
db *sql.DB // database handle used for every query
logger *observability.Logger // receives debug/info/error events for each operation
}
// NewWorkerServiceWithLogger returns a WorkerService that runs its queries
// against db and reports through logger. Both are stored as given.
func NewWorkerServiceWithLogger(db *sql.DB, logger *observability.Logger) *WorkerService {
	svc := new(WorkerService)
	svc.db, svc.logger = db, logger
	return svc
}
// GetSetting returns the value stored in worker_settings for key.
//
// A blank key (empty or whitespace only) is rejected without touching the
// database. When no row exists, the error wraps ErrSettingNotFound, so
// callers can detect it with errors.Is. Other database errors are logged and
// returned wrapped.
func (s *WorkerService) GetSetting(ctx context.Context, key string) (result0 string, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_setting", attribute.String("setting.key", key))
	defer observability.FinishSpan(span, &err)
	// TrimSpace("") == "", so this single test also covers the empty key.
	if strings.TrimSpace(key) == "" {
		return "", contextutils.WrapErrorf(errors.New("invalid setting key"), "setting key cannot be empty")
	}
	var value string
	err = s.db.QueryRowContext(ctx, `
SELECT setting_value FROM worker_settings WHERE setting_key = $1
`, key).Scan(&value)
	if err != nil {
		// errors.Is keeps working if the driver or a wrapper ever wraps ErrNoRows.
		if errors.Is(err, sql.ErrNoRows) {
			s.logger.Debug(ctx, "Setting not found", map[string]interface{}{"setting_key": key})
			return "", contextutils.WrapErrorf(ErrSettingNotFound, "%s", key)
		}
		s.logger.Error(ctx, "Failed to get setting", err, map[string]interface{}{"setting_key": key})
		return "", contextutils.WrapErrorf(err, "failed to get setting %s", key)
	}
	return value, nil
}
// SetSetting upserts key=value into worker_settings and stamps updated_at
// with NOW(). A blank key (empty or whitespace only) is rejected before any
// query runs. Database errors are logged and returned wrapped.
func (s *WorkerService) SetSetting(ctx context.Context, key, value string) (err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "set_setting", attribute.String("setting.key", key))
	defer observability.FinishSpan(span, &err)
	if strings.TrimSpace(key) == "" {
		return contextutils.WrapErrorf(errors.New("invalid setting key"), "setting key cannot be empty")
	}
	const upsert = `
INSERT INTO worker_settings (setting_key, setting_value, updated_at)
VALUES ($1, $2, NOW())
ON CONFLICT (setting_key) DO UPDATE SET
setting_value = EXCLUDED.setting_value,
updated_at = EXCLUDED.updated_at
`
	if _, err = s.db.ExecContext(ctx, upsert, key, value); err != nil {
		s.logger.Error(ctx, "Failed to set setting", err, map[string]interface{}{"setting_key": key, "setting_value": value})
		return contextutils.WrapErrorf(err, "failed to set setting %s", key)
	}
	s.logger.Debug(ctx, "Setting updated", map[string]interface{}{"setting_key": key, "setting_value": value})
	return nil
}
// IsGlobalPaused reports whether the "global_pause" setting equals "true".
//
// If the setting has never been written, it is created with the value
// "false" and the worker is reported as not paused. An error is returned if
// that initialization fails, or if the lookup fails for any other reason.
func (s *WorkerService) IsGlobalPaused(ctx context.Context) (result0 bool, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "is_global_paused")
	defer observability.FinishSpan(span, &err)
	value, getErr := s.GetSetting(ctx, "global_pause")
	switch {
	case getErr == nil:
		paused := value == "true"
		s.logger.Debug(ctx, "Global pause status checked", map[string]interface{}{"global_paused": paused})
		return paused, nil
	case errors.Is(getErr, ErrSettingNotFound):
		// First use: persist the default so later reads find a row.
		if setErr := s.SetSetting(ctx, "global_pause", "false"); setErr != nil {
			s.logger.Error(ctx, "Failed to initialize global_pause setting", setErr, map[string]interface{}{})
			return false, contextutils.WrapError(setErr, "failed to initialize global_pause setting")
		}
		return false, nil
	default:
		s.logger.Error(ctx, "Failed to check global pause status", getErr, map[string]interface{}{})
		return false, getErr
	}
}
// SetGlobalPause stores the global pause flag as the "global_pause" setting
// ("true" or "false"). Errors from SetSetting are returned unchanged.
func (s *WorkerService) SetGlobalPause(ctx context.Context, paused bool) (err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "set_global_pause", attribute.Bool("paused", paused))
	defer observability.FinishSpan(span, &err)
	value := "true"
	if !paused {
		value = "false"
	}
	if err = s.SetSetting(ctx, "global_pause", value); err != nil {
		return err
	}
	s.logger.Info(ctx, "Global pause state updated", map[string]interface{}{"global_paused": paused})
	return nil
}
// IsUserPaused reports whether the "user_pause_<userID>" setting equals "true".
//
// A missing setting is the default state and means "not paused". Unlike
// IsGlobalPaused, nothing is written in that case. Other database errors are
// logged and returned wrapped.
func (s *WorkerService) IsUserPaused(ctx context.Context, userID int) (result0 bool, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "is_user_paused", observability.AttributeUserID(userID))
	defer observability.FinishSpan(span, &err)
	key := fmt.Sprintf("user_pause_%d", userID)
	var value string
	err = s.db.QueryRowContext(ctx, `
SELECT setting_value FROM worker_settings WHERE setting_key = $1
`, key).Scan(&value)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// If setting doesn't exist, user is not paused (this is the default state)
			s.logger.Debug(ctx, "User pause setting not found, defaulting to not paused", map[string]interface{}{"user_id": userID})
			return false, nil
		}
		s.logger.Error(ctx, "Failed to check user pause status", err, map[string]interface{}{"user_id": userID})
		return false, contextutils.WrapErrorf(err, "failed to check user pause status for user %d", userID)
	}
	paused := value == "true"
	s.logger.Debug(ctx, "User pause status checked", map[string]interface{}{"user_id": userID, "user_paused": paused})
	return paused, nil
}
// SetUserPause stores a per-user pause flag as the "user_pause_<userID>"
// setting ("true" or "false"). Errors from SetSetting are returned unchanged.
func (s *WorkerService) SetUserPause(ctx context.Context, userID int, paused bool) (err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "set_user_pause", observability.AttributeUserID(userID), attribute.Bool("paused", paused))
	defer observability.FinishSpan(span, &err)
	value := "true"
	if !paused {
		value = "false"
	}
	if err = s.SetSetting(ctx, fmt.Sprintf("user_pause_%d", userID), value); err != nil {
		return err
	}
	s.logger.Info(ctx, "User pause state updated", map[string]interface{}{"user_id": userID, "user_paused": paused})
	return nil
}
// UpdateWorkerStatus upserts the full worker_status row for instance. Every
// stored column is overwritten with the values in status, and updated_at is
// set to NOW().
//
// A nil status is rejected with an error instead of panicking on the
// dereference. Database errors are logged and returned wrapped.
func (s *WorkerService) UpdateWorkerStatus(ctx context.Context, instance string, status *models.WorkerStatus) (err error) {
	if status == nil {
		return contextutils.WrapErrorf(errors.New("worker status is nil"), "failed to update worker status for instance %s", instance)
	}
	// CurrentActivity is nullable; use "" for span/log attributes when unset.
	activity := ""
	if status.CurrentActivity.Valid {
		activity = status.CurrentActivity.String
	}
	ctx, span := observability.TraceWorkerFunction(ctx, "update_worker_status",
		attribute.String("worker.instance", instance),
		attribute.Bool("worker.is_running", status.IsRunning),
		attribute.Bool("worker.is_paused", status.IsPaused),
		attribute.String("worker.activity", activity),
	)
	defer observability.FinishSpan(span, &err)
	_, err = s.db.ExecContext(ctx, `
INSERT INTO worker_status (
worker_instance, is_running, is_paused, current_activity,
last_heartbeat, last_run_start, last_run_finish, last_run_error,
total_questions_generated, total_runs, updated_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, NOW())
ON CONFLICT (worker_instance) DO UPDATE SET
is_running = EXCLUDED.is_running,
is_paused = EXCLUDED.is_paused,
current_activity = EXCLUDED.current_activity,
last_heartbeat = EXCLUDED.last_heartbeat,
last_run_start = EXCLUDED.last_run_start,
last_run_finish = EXCLUDED.last_run_finish,
last_run_error = EXCLUDED.last_run_error,
total_questions_generated = EXCLUDED.total_questions_generated,
total_runs = EXCLUDED.total_runs,
updated_at = EXCLUDED.updated_at
`, instance, status.IsRunning, status.IsPaused, status.CurrentActivity,
		status.LastHeartbeat, status.LastRunStart, status.LastRunFinish,
		status.LastRunError, status.TotalQuestionsGenerated, status.TotalRuns)
	if err != nil {
		s.logger.Error(ctx, "Failed to update worker status", err, map[string]interface{}{
			"worker_instance": instance,
			"is_running":      status.IsRunning,
			"is_paused":       status.IsPaused,
			"activity":        activity,
		})
		return contextutils.WrapErrorf(err, "failed to update worker status for instance %s", instance)
	}
	s.logger.Debug(ctx, "Worker status updated", map[string]interface{}{
		"worker_instance": instance,
		"is_running":      status.IsRunning,
		"is_paused":       status.IsPaused,
		"activity":        activity,
	})
	return nil
}
// GetWorkerStatus loads the worker_status row for instance.
//
// When no row exists the returned error wraps sql.ErrNoRows, which callers
// can detect with errors.Is. That case is logged at debug level; other
// database errors are logged as errors. Both are returned wrapped.
func (s *WorkerService) GetWorkerStatus(ctx context.Context, instance string) (result0 *models.WorkerStatus, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_worker_status", attribute.String("worker.instance", instance))
	defer observability.FinishSpan(span, &err)
	var status models.WorkerStatus
	err = s.db.QueryRowContext(ctx, `
SELECT id, worker_instance, is_running, is_paused, current_activity,
last_heartbeat, last_run_start, last_run_finish, last_run_error,
total_questions_generated, total_runs, created_at, updated_at
FROM worker_status WHERE worker_instance = $1
`, instance).Scan(
		&status.ID, &status.WorkerInstance, &status.IsRunning, &status.IsPaused,
		&status.CurrentActivity, &status.LastHeartbeat, &status.LastRunStart,
		&status.LastRunFinish, &status.LastRunError, &status.TotalQuestionsGenerated,
		&status.TotalRuns, &status.CreatedAt, &status.UpdatedAt,
	)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			s.logger.Debug(ctx, "Worker status not found", map[string]interface{}{"worker_instance": instance})
			return nil, contextutils.WrapErrorf(err, "worker status not found for instance %s", instance)
		}
		s.logger.Error(ctx, "Failed to get worker status", err, map[string]interface{}{"worker_instance": instance})
		return nil, contextutils.WrapErrorf(err, "failed to get worker status for instance %s", instance)
	}
	return &status, nil
}
// GetAllWorkerStatuses loads every worker_status row, ordered by
// worker_instance. It returns a nil slice when the table is empty.
// Query, scan and iteration errors are logged and returned wrapped; a
// failure to close the result set is only logged.
func (s *WorkerService) GetAllWorkerStatuses(ctx context.Context) (result0 []models.WorkerStatus, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_all_worker_statuses")
	defer observability.FinishSpan(span, &err)
	const query = `
SELECT id, worker_instance, is_running, is_paused, current_activity,
last_heartbeat, last_run_start, last_run_finish, last_run_error,
total_questions_generated, total_runs, created_at, updated_at
FROM worker_status ORDER BY worker_instance
`
	var rows *sql.Rows
	if rows, err = s.db.QueryContext(ctx, query); err != nil {
		s.logger.Error(ctx, "Failed to get all worker statuses", err, map[string]interface{}{})
		return nil, contextutils.WrapError(err, "failed to get all worker statuses")
	}
	defer func() {
		closeErr := rows.Close()
		if closeErr == nil {
			return
		}
		s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
	}()
	var statuses []models.WorkerStatus
	for rows.Next() {
		var ws models.WorkerStatus
		if err = rows.Scan(
			&ws.ID, &ws.WorkerInstance, &ws.IsRunning, &ws.IsPaused,
			&ws.CurrentActivity, &ws.LastHeartbeat, &ws.LastRunStart,
			&ws.LastRunFinish, &ws.LastRunError, &ws.TotalQuestionsGenerated,
			&ws.TotalRuns, &ws.CreatedAt, &ws.UpdatedAt,
		); err != nil {
			s.logger.Error(ctx, "Failed to scan worker status row", err, map[string]interface{}{})
			return nil, contextutils.WrapError(err, "failed to scan worker status row")
		}
		statuses = append(statuses, ws)
	}
	if iterErr := rows.Err(); iterErr != nil {
		s.logger.Error(ctx, "Error iterating worker status rows", iterErr, map[string]interface{}{})
		return nil, contextutils.WrapError(iterErr, "error iterating worker status rows")
	}
	s.logger.Debug(ctx, "Retrieved all worker statuses", map[string]interface{}{"count": len(statuses)})
	return statuses, nil
}
// UpdateHeartbeat sets last_heartbeat and updated_at to NOW() for instance.
// If the instance has no worker_status row yet, one is inserted with only
// those columns. Database errors are logged and returned wrapped.
func (s *WorkerService) UpdateHeartbeat(ctx context.Context, instance string) (err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "update_heartbeat", attribute.String("worker.instance", instance))
	defer observability.FinishSpan(span, &err)
	const upsert = `
INSERT INTO worker_status (worker_instance, last_heartbeat, updated_at)
VALUES ($1, NOW(), NOW())
ON CONFLICT (worker_instance) DO UPDATE SET
last_heartbeat = EXCLUDED.last_heartbeat,
updated_at = EXCLUDED.updated_at
`
	if _, err = s.db.ExecContext(ctx, upsert, instance); err != nil {
		s.logger.Error(ctx, "Failed to update heartbeat", err, map[string]interface{}{"worker_instance": instance})
		return contextutils.WrapErrorf(err, "failed to update heartbeat for instance %s", instance)
	}
	s.logger.Debug(ctx, "Heartbeat updated", map[string]interface{}{"worker_instance": instance})
	return nil
}
// IsWorkerHealthy reports whether instance's last heartbeat is less than
// five minutes old.
//
// An instance with no worker_status row, or with a NULL last_heartbeat, is
// reported as unhealthy without an error. Other database errors are logged
// and returned wrapped.
func (s *WorkerService) IsWorkerHealthy(ctx context.Context, instance string) (result0 bool, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "is_worker_healthy", attribute.String("worker.instance", instance))
	defer observability.FinishSpan(span, &err)
	var lastHeartbeat sql.NullTime
	err = s.db.QueryRowContext(ctx, `
SELECT last_heartbeat FROM worker_status WHERE worker_instance = $1
`, instance).Scan(&lastHeartbeat)
	if err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			s.logger.Debug(ctx, "Worker not found, considered unhealthy", map[string]interface{}{"worker_instance": instance})
			return false, nil
		}
		s.logger.Error(ctx, "Failed to check worker health", err, map[string]interface{}{"worker_instance": instance})
		return false, contextutils.WrapErrorf(err, "failed to check worker health for instance %s", instance)
	}
	if !lastHeartbeat.Valid {
		s.logger.Debug(ctx, "Worker has no heartbeat, considered unhealthy", map[string]interface{}{"worker_instance": instance})
		return false, nil
	}
	// Consider worker healthy if heartbeat is within the last 5 minutes.
	// Measure the age once so the logged value matches the decision.
	const heartbeatWindow = 5 * time.Minute
	age := time.Since(lastHeartbeat.Time)
	healthy := age < heartbeatWindow
	s.logger.Debug(ctx, "Worker health checked", map[string]interface{}{
		"worker_instance": instance,
		"healthy":         healthy,
		"last_heartbeat":  lastHeartbeat.Time,
		"time_since":      age.String(),
	})
	return healthy, nil
}
// PauseWorker sets is_paused = true (and bumps updated_at) on instance's
// existing worker_status row. The affected-row count is not checked, so an
// unknown instance is not reported as an error. Database errors are logged
// and returned wrapped.
func (s *WorkerService) PauseWorker(ctx context.Context, instance string) (err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "pause_worker", attribute.String("worker.instance", instance))
	defer observability.FinishSpan(span, &err)
	const stmt = `
UPDATE worker_status SET is_paused = true, updated_at = NOW()
WHERE worker_instance = $1
`
	if _, err = s.db.ExecContext(ctx, stmt, instance); err != nil {
		s.logger.Error(ctx, "Failed to pause worker", err, map[string]interface{}{"worker_instance": instance})
		return contextutils.WrapErrorf(err, "failed to pause worker instance %s", instance)
	}
	s.logger.Info(ctx, "Worker paused", map[string]interface{}{"worker_instance": instance})
	return nil
}
// ResumeWorker sets is_paused = false (and bumps updated_at) on instance's
// existing worker_status row. The affected-row count is not checked, so an
// unknown instance is not reported as an error. Database errors are logged
// and returned wrapped.
func (s *WorkerService) ResumeWorker(ctx context.Context, instance string) (err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "resume_worker", attribute.String("worker.instance", instance))
	defer observability.FinishSpan(span, &err)
	const stmt = `
UPDATE worker_status SET is_paused = false, updated_at = NOW()
WHERE worker_instance = $1
`
	if _, err = s.db.ExecContext(ctx, stmt, instance); err != nil {
		s.logger.Error(ctx, "Failed to resume worker", err, map[string]interface{}{"worker_instance": instance})
		return contextutils.WrapErrorf(err, "failed to resume worker instance %s", instance)
	}
	s.logger.Info(ctx, "Worker resumed", map[string]interface{}{"worker_instance": instance})
	return nil
}
// GetWorkerHealth builds a health summary across all workers in worker_status.
//
// The returned map has these keys:
//   - "global_paused": result of IsGlobalPaused, or false if that lookup fails
//   - "worker_instances": one map per worker whose health check succeeded
//   - "total_count": number of worker_status rows
//   - "healthy_count": number of workers whose heartbeat is recent enough
//
// A worker whose individual health check fails is logged and left out of
// "worker_instances", but it is still counted in "total_count". Only a
// failure to list the workers is returned as an error.
func (s *WorkerService) GetWorkerHealth(ctx context.Context) (result0 map[string]interface{}, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_worker_health")
	defer observability.FinishSpan(span, &err)
	var statuses []models.WorkerStatus
	statuses, err = s.GetAllWorkerStatuses(ctx)
	if err != nil {
		return nil, err
	}
	// A pause-state lookup failure degrades to "not paused" instead of
	// failing the report. The local variable keeps this tolerated error out
	// of the named result.
	globalPaused, pauseErr := s.IsGlobalPaused(ctx)
	if pauseErr != nil {
		s.logger.Error(ctx, "Failed to get global pause state", pauseErr, map[string]interface{}{})
		globalPaused = false // Default to false if we can't get the state
	}
	workerInstances := make([]map[string]interface{}, 0, len(statuses))
	healthyCount := 0
	totalCount := len(statuses)
	for _, status := range statuses {
		healthy, healthErr := s.IsWorkerHealthy(ctx, status.WorkerInstance)
		if healthErr != nil {
			s.logger.Error(ctx, "Failed to check health for worker", healthErr, map[string]interface{}{"worker_instance": status.WorkerInstance})
			continue
		}
		if healthy {
			healthyCount++
		}
		workerInstances = append(workerInstances, map[string]interface{}{
			"worker_instance":           status.WorkerInstance,
			"healthy":                   healthy,
			"is_running":                status.IsRunning,
			"is_paused":                 status.IsPaused,
			"last_heartbeat":            status.LastHeartbeat,
			"total_questions_generated": status.TotalQuestionsGenerated,
			"total_runs":                status.TotalRuns,
		})
	}
	// Build comprehensive health summary
	health := map[string]interface{}{
		"global_paused":    globalPaused,
		"worker_instances": workerInstances,
		"total_count":      totalCount,
		"healthy_count":    healthyCount,
	}
	// Log the number of workers. len(health) would always be the number of
	// summary keys.
	s.logger.Debug(ctx, "Worker health retrieved", map[string]interface{}{"worker_count": totalCount, "healthy_count": healthyCount})
	return health, nil
}
// GetHighPriorityTopics returns up to five topic categories, highest average
// first, whose mean question_priority_scores.priority_score for this user is
// at least 7.0. Only questions linked to the user in user_questions that match
// language, level and questionType and have a non-empty topic_category are
// considered. It returns a nil slice when no topic qualifies.
func (s *WorkerService) GetHighPriorityTopics(ctx context.Context, userID int, language, level, questionType string) (result0 []string, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_high_priority_topics",
		observability.AttributeUserID(userID),
		observability.AttributeLanguage(language),
		observability.AttributeLevel(level),
		attribute.String("question.type", questionType),
	)
	defer observability.FinishSpan(span, &err)
	const query = `
SELECT q.topic_category, AVG(qps.priority_score) as avg_score
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
JOIN question_priority_scores qps ON q.id = qps.question_id AND qps.user_id = $1
WHERE uq.user_id = $1
AND q.language = $2
AND q.level = $3
AND q.type = $4
AND q.topic_category IS NOT NULL
AND q.topic_category != ''
GROUP BY q.topic_category
HAVING AVG(qps.priority_score) >= 7.0
ORDER BY avg_score DESC
LIMIT 5
`
	rows, queryErr := s.db.QueryContext(ctx, query, userID, language, level, questionType)
	if queryErr != nil {
		s.logger.Error(ctx, "Failed to get high priority topics", queryErr, map[string]interface{}{
			"user_id": userID, "language": language, "level": level, "question_type": questionType,
		})
		return nil, contextutils.WrapError(queryErr, "failed to get high priority topics")
	}
	defer func() {
		closeErr := rows.Close()
		if closeErr == nil {
			return
		}
		s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
	}()
	var topics []string
	for rows.Next() {
		var (
			topic    string
			avgScore float64 // scanned to match the SELECT list; ordering already happened in SQL
		)
		if scanErr := rows.Scan(&topic, &avgScore); scanErr != nil {
			s.logger.Error(ctx, "Failed to scan high priority topics row", scanErr, map[string]interface{}{})
			return nil, contextutils.WrapError(scanErr, "failed to scan high priority topics row")
		}
		topics = append(topics, topic)
	}
	if iterErr := rows.Err(); iterErr != nil {
		s.logger.Error(ctx, "Error iterating high priority topics rows", iterErr, map[string]interface{}{})
		return nil, contextutils.WrapError(iterErr, "error iterating high priority topics rows")
	}
	s.logger.Debug(ctx, "Retrieved high priority topics", map[string]interface{}{"user_id": userID, "count": len(topics)})
	return topics, nil
}
// GetGapAnalysis identifies areas with poor user performance (knowledge gaps).
//
// The query looks at the questions linked to userID through user_questions that
// match language, level and questionType. It groups them by the combination of
// topic_category, grammar_focus, vocabulary_domain and scenario, and computes a
// per-group accuracy percentage from user_responses. Groups below 60% accuracy
// are kept. Each kept group then yields one row per dimension: "topic" always
// (a NULL topic_category becomes "unknown"), and "grammar", "vocabulary" and
// "scenario" only when that column is not NULL.
//
// The result maps "<gap_type>_<area>" to the summed total_questions of every
// low-accuracy group that shares that key. total_questions is COUNT(*) over the
// question/response join. A larger value means more poorly-answered material in
// that area.
func (s *WorkerService) GetGapAnalysis(ctx context.Context, userID int, language, level, questionType string) (result0 map[string]int, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_gap_analysis",
		observability.AttributeUserID(userID),
		observability.AttributeLanguage(language),
		observability.AttributeLevel(level),
		attribute.String("question.type", questionType),
	)
	defer observability.FinishSpan(span, &err)

	// user_performance aggregates accuracy per dimension combination; the
	// UNION ALL then emits one row per dimension of each low-accuracy group.
	query := `
		WITH user_performance AS (
			SELECT
				q.topic_category,
				q.grammar_focus,
				q.vocabulary_domain,
				q.scenario,
				COUNT(*) as total_questions,
				COUNT(CASE WHEN ur.is_correct = true THEN 1 END) as correct_answers,
				ROUND(
					COUNT(CASE WHEN ur.is_correct = true THEN 1 END)::decimal / COUNT(*)::decimal * 100, 2
				) as accuracy_percentage
			FROM questions q
			JOIN user_questions uq ON q.id = uq.question_id
			LEFT JOIN user_responses ur ON q.id = ur.question_id AND ur.user_id = $1
			WHERE uq.user_id = $1
				AND q.language = $2
				AND q.level = $3
				AND q.type = $4
			GROUP BY q.topic_category, q.grammar_focus, q.vocabulary_domain, q.scenario
		)
		SELECT
			COALESCE(topic_category, 'unknown') as area,
			'topic' as gap_type,
			total_questions,
			accuracy_percentage
		FROM user_performance
		WHERE accuracy_percentage < 60 OR accuracy_percentage IS NULL
		UNION ALL
		SELECT
			COALESCE(grammar_focus, 'unknown') as area,
			'grammar' as gap_type,
			total_questions,
			accuracy_percentage
		FROM user_performance
		WHERE (accuracy_percentage < 60 OR accuracy_percentage IS NULL) AND grammar_focus IS NOT NULL
		UNION ALL
		SELECT
			COALESCE(vocabulary_domain, 'unknown') as area,
			'vocabulary' as gap_type,
			total_questions,
			accuracy_percentage
		FROM user_performance
		WHERE (accuracy_percentage < 60 OR accuracy_percentage IS NULL) AND vocabulary_domain IS NOT NULL
		UNION ALL
		SELECT
			COALESCE(scenario, 'unknown') as area,
			'scenario' as gap_type,
			total_questions,
			accuracy_percentage
		FROM user_performance
		WHERE (accuracy_percentage < 60 OR accuracy_percentage IS NULL) AND scenario IS NOT NULL
		ORDER BY accuracy_percentage ASC, total_questions DESC
	`
	rows, err := s.db.QueryContext(ctx, query, userID, language, level, questionType)
	if err != nil {
		s.logger.Error(ctx, "Failed to get gap analysis", err, map[string]interface{}{
			"user_id": userID, "language": language, "level": level, "question_type": questionType,
		})
		return nil, contextutils.WrapError(err, "failed to get gap analysis")
	}
	defer func() {
		if err := rows.Close(); err != nil {
			s.logger.Error(ctx, "Failed to close rows", err, map[string]interface{}{})
		}
	}()

	gaps := make(map[string]int)
	for rows.Next() {
		var area, gapType string
		var totalQuestions int
		// accuracy_percentage is only used for filtering/ordering in SQL, but
		// Scan must consume every selected column.
		var accuracyPercentage sql.NullFloat64
		if err := rows.Scan(&area, &gapType, &totalQuestions, &accuracyPercentage); err != nil {
			s.logger.Error(ctx, "Failed to scan gap analysis row", err, map[string]interface{}{})
			return nil, contextutils.WrapError(err, "failed to scan gap analysis row")
		}
		// Prefix with the gap type so e.g. a topic and a scenario sharing a
		// name stay distinct.
		key := gapType + "_" + area
		// The same area value occurs in several dimension combinations, so
		// accumulate. Overwriting would keep only whichever row came last.
		gaps[key] += totalQuestions
	}
	if err := rows.Err(); err != nil {
		s.logger.Error(ctx, "Error iterating gap analysis rows", err, map[string]interface{}{})
		return nil, contextutils.WrapError(err, "error iterating gap analysis rows")
	}

	s.logger.Debug(ctx, "Retrieved gap analysis", map[string]interface{}{"user_id": userID, "count": len(gaps)})
	return gaps, nil
}
// GetPriorityDistribution returns, per topic, how many of the user's questions
// carry a priority score.
//
// Only questions that meet all of the following are counted:
//   - linked to userID through user_questions
//   - matching language, level and questionType
//   - having a question_priority_scores row for the same user
//
// Questions with a NULL topic_category are reported under "unknown", the label
// used elsewhere in this service. Previously such a group failed the whole call,
// because NULL cannot be scanned into a string. The map is keyed by topic, and
// each value is the number of questions in that topic.
func (s *WorkerService) GetPriorityDistribution(ctx context.Context, userID int, language, level, questionType string) (result0 map[string]int, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_priority_distribution",
		observability.AttributeUserID(userID),
		observability.AttributeLanguage(language),
		observability.AttributeLevel(level),
		attribute.String("question.type", questionType),
	)
	defer observability.FinishSpan(span, &err)

	// Group on the COALESCEd value so a real 'unknown' topic and NULL topics
	// merge into one row instead of producing two rows with the same key.
	query := `
		SELECT COALESCE(q.topic_category, 'unknown') AS topic, COUNT(*) AS question_count
		FROM questions q
		JOIN user_questions uq ON q.id = uq.question_id
		JOIN question_priority_scores qps ON q.id = qps.question_id AND qps.user_id = $1
		WHERE uq.user_id = $1
			AND q.language = $2
			AND q.level = $3
			AND q.type = $4
		GROUP BY COALESCE(q.topic_category, 'unknown')
	`
	rows, err := s.db.QueryContext(ctx, query, userID, language, level, questionType)
	if err != nil {
		s.logger.Error(ctx, "Failed to get priority distribution", err, map[string]interface{}{
			"user_id": userID, "language": language, "level": level, "question_type": questionType,
		})
		return nil, contextutils.WrapError(err, "failed to get priority distribution")
	}
	defer func() {
		if err := rows.Close(); err != nil {
			s.logger.Error(ctx, "Failed to close rows", err, map[string]interface{}{})
		}
	}()

	distribution := make(map[string]int)
	for rows.Next() {
		var topic string
		var count int
		if err := rows.Scan(&topic, &count); err != nil {
			s.logger.Error(ctx, "Failed to scan priority distribution row", err, map[string]interface{}{})
			return nil, contextutils.WrapError(err, "failed to scan priority distribution row")
		}
		distribution[topic] = count
	}
	if err := rows.Err(); err != nil {
		s.logger.Error(ctx, "Error iterating priority distribution rows", err, map[string]interface{}{})
		return nil, contextutils.WrapError(err, "error iterating priority distribution rows")
	}

	s.logger.Debug(ctx, "Retrieved priority distribution", map[string]interface{}{"user_id": userID, "count": len(distribution)})
	return distribution, nil
}
// GetNotificationStats returns comprehensive notification statistics
func (s *WorkerService) GetNotificationStats(ctx context.Context) (result0 map[string]interface{}, err error) {
ctx, span := observability.TraceWorkerFunction(ctx, "get_notification_stats")
defer observability.FinishSpan(span, &err)
// Get total notifications sent
var totalSent int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM sent_notifications WHERE status = 'sent'
`).Scan(&totalSent)
if err != nil {
s.logger.Error(ctx, "Failed to get total notifications sent", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to get total notifications sent")
}
// Get total notifications failed
var totalFailed int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM sent_notifications WHERE status = 'failed'
`).Scan(&totalFailed)
if err != nil {
s.logger.Error(ctx, "Failed to get total notifications failed", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to get total notifications failed")
}
// Calculate success rate
var successRate float64
if totalSent+totalFailed > 0 {
successRate = float64(totalSent) / float64(totalSent+totalFailed)
}
// Get users with notifications enabled
var usersWithNotifications int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(DISTINCT user_id) FROM user_learning_preferences WHERE daily_reminder_enabled = true
`).Scan(&usersWithNotifications)
if err != nil {
s.logger.Error(ctx, "Failed to get users with notifications enabled", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to get users with notifications enabled")
}
// Get total users
var totalUsers int
err = s.db.QueryRowContext(ctx, `SELECT COUNT(*) FROM users`).Scan(&totalUsers)
if err != nil {
s.logger.Error(ctx, "Failed to get total users", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to get total users")
}
// Get notifications sent today
var sentToday int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM sent_notifications
WHERE status = 'sent' AND DATE(sent_at) = CURRENT_DATE
`).Scan(&sentToday)
if err != nil {
s.logger.Error(ctx, "Failed to get notifications sent today", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to get notifications sent today")
}
// Get notifications sent this week
var sentThisWeek int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM sent_notifications
WHERE status = 'sent' AND sent_at >= DATE_TRUNC('week', CURRENT_DATE)
`).Scan(&sentThisWeek)
if err != nil {
s.logger.Error(ctx, "Failed to get notifications sent this week", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to get notifications sent this week")
}
// Get upcoming notifications
var upcomingNotifications int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM upcoming_notifications WHERE status = 'pending'
`).Scan(&upcomingNotifications)
if err != nil {
s.logger.Error(ctx, "Failed to get upcoming notifications", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to get upcoming notifications")
}
// Get unresolved errors
var unresolvedErrors int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM notification_errors WHERE resolved_at IS NULL
`).Scan(&unresolvedErrors)
if err != nil {
s.logger.Error(ctx, "Failed to get unresolved errors", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to get unresolved errors")
}
// Get notifications by type
notificationsByType := make(map[string]int)
rows, err := s.db.QueryContext(ctx, `
SELECT notification_type, COUNT(*)
FROM sent_notifications
WHERE status = 'sent'
GROUP BY notification_type
`)
if err != nil {
s.logger.Error(ctx, "Failed to get notifications by type", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to get notifications by type")
}
defer func() {
if closeErr := rows.Close(); closeErr != nil {
s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
}
}()
for rows.Next() {
var notificationType string
var count int
if err := rows.Scan(¬ificationType, &count); err != nil {
s.logger.Error(ctx, "Failed to scan notifications by type", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to scan notifications by type")
}
notificationsByType[notificationType] = count
}
// Get errors by type
errorsByType := make(map[string]int)
rows, err = s.db.QueryContext(ctx, `
SELECT error_type, COUNT(*)
FROM notification_errors
GROUP BY error_type
`)
if err != nil {
s.logger.Error(ctx, "Failed to get errors by type", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to get errors by type")
}
defer func() {
if closeErr := rows.Close(); closeErr != nil {
s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
}
}()
for rows.Next() {
var errorType string
var count int
if err := rows.Scan(&errorType, &count); err != nil {
s.logger.Error(ctx, "Failed to scan errors by type", err, map[string]interface{}{})
return nil, contextutils.WrapError(err, "failed to scan errors by type")
}
errorsByType[errorType] = count
}
stats := map[string]interface{}{
"total_notifications_sent": totalSent,
"total_notifications_failed": totalFailed,
"success_rate": successRate,
"users_with_notifications_enabled": usersWithNotifications,
"total_users": totalUsers,
"notifications_sent_today": sentToday,
"notifications_sent_this_week": sentThisWeek,
"notifications_by_type": notificationsByType,
"errors_by_type": errorsByType,
"upcoming_notifications": upcomingNotifications,
"unresolved_errors": unresolvedErrors,
}
s.logger.Debug(ctx, "Retrieved notification stats", map[string]interface{}{"stats": stats})
return stats, nil
}
// GetNotificationErrors returns paginated notification errors with filtering
func (s *WorkerService) GetNotificationErrors(ctx context.Context, page, pageSize int, errorType, notificationType, resolved string) (result0 []map[string]interface{}, result1, result2 map[string]interface{}, err error) {
ctx, span := observability.TraceWorkerFunction(ctx, "get_notification_errors",
attribute.Int("page", page),
attribute.Int("page_size", pageSize),
attribute.String("error_type", errorType),
attribute.String("notification_type", notificationType),
attribute.String("resolved", resolved),
)
defer observability.FinishSpan(span, &err)
// Build WHERE clause
whereConditions := []string{}
args := []interface{}{}
argIndex := 1
if errorType != "" {
whereConditions = append(whereConditions, fmt.Sprintf("error_type = $%d", argIndex))
args = append(args, errorType)
argIndex++
}
if notificationType != "" {
whereConditions = append(whereConditions, fmt.Sprintf("notification_type = $%d", argIndex))
args = append(args, notificationType)
argIndex++
}
switch resolved {
case "true":
whereConditions = append(whereConditions, "resolved_at IS NOT NULL")
case "false":
whereConditions = append(whereConditions, "resolved_at IS NULL")
}
whereClause := ""
if len(whereConditions) > 0 {
whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
}
// Get total count
var totalErrors int
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM notification_errors %s", whereClause)
err = s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalErrors)
if err != nil {
s.logger.Error(ctx, "Failed to get total notification errors", err, map[string]interface{}{})
return nil, nil, nil, contextutils.WrapError(err, "failed to get total notification errors")
}
// Calculate pagination
offset := (page - 1) * pageSize
totalPages := (totalErrors + pageSize - 1) / pageSize
// Get errors with pagination
args = append(args, pageSize, offset)
query := fmt.Sprintf(`
SELECT ne.id, ne.user_id, u.username, ne.notification_type, ne.error_type,
ne.error_message, ne.email_address, ne.occurred_at, ne.resolved_at, ne.resolution_notes
FROM notification_errors ne
LEFT JOIN users u ON ne.user_id = u.id
%s
ORDER BY ne.occurred_at DESC
LIMIT $%d OFFSET $%d
`, whereClause, argIndex, argIndex+1)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
s.logger.Error(ctx, "Failed to get notification errors", err, map[string]interface{}{})
return nil, nil, nil, contextutils.WrapError(err, "failed to get notification errors")
}
defer func() {
if closeErr := rows.Close(); closeErr != nil {
s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
}
}()
var errors []map[string]interface{}
for rows.Next() {
var errorData map[string]interface{}
var id int
var userID sql.NullInt64
var username sql.NullString
var notificationType, errorType, errorMessage string
var emailAddress sql.NullString
var occurredAt time.Time
var resolvedAt sql.NullTime
var resolutionNotes sql.NullString
err := rows.Scan(&id, &userID, &username, ¬ificationType, &errorType, &errorMessage, &emailAddress, &occurredAt, &resolvedAt, &resolutionNotes)
if err != nil {
s.logger.Error(ctx, "Failed to scan notification error", err, map[string]interface{}{})
return nil, nil, nil, contextutils.WrapError(err, "failed to scan notification error")
}
errorData = map[string]interface{}{
"id": id,
"notification_type": notificationType,
"error_type": errorType,
"error_message": errorMessage,
"occurred_at": occurredAt.Format(time.RFC3339),
}
if userID.Valid {
errorData["user_id"] = userID.Int64
}
if username.Valid {
errorData["username"] = username.String
}
if emailAddress.Valid {
errorData["email_address"] = emailAddress.String
}
if resolvedAt.Valid {
errorData["resolved_at"] = resolvedAt.Time.Format(time.RFC3339)
}
if resolutionNotes.Valid {
errorData["resolution_notes"] = resolutionNotes.String
}
errors = append(errors, errorData)
}
// Get stats
stats := map[string]interface{}{
"total_errors": totalErrors,
"unresolved_errors": 0, // Will be calculated separately
}
// Get unresolved errors count
var unresolvedCount int
err = s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM notification_errors WHERE resolved_at IS NULL").Scan(&unresolvedCount)
if err != nil {
s.logger.Error(ctx, "Failed to get unresolved errors count", err, map[string]interface{}{})
} else {
stats["unresolved_errors"] = unresolvedCount
}
// Get errors by type
errorsByType := make(map[string]int)
rows, err = s.db.QueryContext(ctx, "SELECT error_type, COUNT(*) FROM notification_errors GROUP BY error_type")
if err != nil {
s.logger.Error(ctx, "Failed to get errors by type", err, map[string]interface{}{})
} else {
defer func() {
if closeErr := rows.Close(); closeErr != nil {
s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
}
}()
for rows.Next() {
var errorType string
var count int
if err := rows.Scan(&errorType, &count); err != nil {
s.logger.Error(ctx, "Failed to scan errors by type", err, map[string]interface{}{})
continue
}
errorsByType[errorType] = count
}
stats["errors_by_type"] = errorsByType
}
// Get errors by notification type
errorsByNotificationType := make(map[string]int)
rows, err = s.db.QueryContext(ctx, "SELECT notification_type, COUNT(*) FROM notification_errors GROUP BY notification_type")
if err != nil {
s.logger.Error(ctx, "Failed to get errors by notification type", err, map[string]interface{}{})
} else {
defer func() {
if closeErr := rows.Close(); closeErr != nil {
s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
}
}()
for rows.Next() {
var notificationType string
var count int
if err := rows.Scan(¬ificationType, &count); err != nil {
s.logger.Error(ctx, "Failed to scan errors by notification type", err, map[string]interface{}{})
continue
}
errorsByNotificationType[notificationType] = count
}
stats["errors_by_notification_type"] = errorsByNotificationType
}
pagination := map[string]interface{}{
"page": page,
"page_size": pageSize,
"total": totalErrors,
"total_pages": totalPages,
}
s.logger.Debug(ctx, "Retrieved notification errors", map[string]interface{}{
"count": len(errors), "page": page, "total": totalErrors,
})
return errors, pagination, stats, nil
}
// GetUpcomingNotifications returns paginated upcoming notifications with filtering
func (s *WorkerService) GetUpcomingNotifications(ctx context.Context, page, pageSize int, notificationType, status, scheduledAfter, scheduledBefore string) (result0 []map[string]interface{}, result1, result2 map[string]interface{}, err error) {
ctx, span := observability.TraceWorkerFunction(ctx, "get_upcoming_notifications",
attribute.Int("page", page),
attribute.Int("page_size", pageSize),
attribute.String("notification_type", notificationType),
attribute.String("status", status),
attribute.String("scheduled_after", scheduledAfter),
attribute.String("scheduled_before", scheduledBefore),
)
defer observability.FinishSpan(span, &err)
// Build WHERE clause
whereConditions := []string{}
args := []interface{}{}
argIndex := 1
if notificationType != "" {
whereConditions = append(whereConditions, fmt.Sprintf("notification_type = $%d", argIndex))
args = append(args, notificationType)
argIndex++
}
if status != "" {
whereConditions = append(whereConditions, fmt.Sprintf("status = $%d", argIndex))
args = append(args, status)
argIndex++
}
if scheduledAfter != "" {
whereConditions = append(whereConditions, fmt.Sprintf("scheduled_for >= $%d", argIndex))
args = append(args, scheduledAfter)
argIndex++
}
if scheduledBefore != "" {
whereConditions = append(whereConditions, fmt.Sprintf("scheduled_for <= $%d", argIndex))
args = append(args, scheduledBefore)
argIndex++
}
whereClause := ""
if len(whereConditions) > 0 {
whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
}
// Get total count
var totalNotifications int
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM upcoming_notifications %s", whereClause)
err = s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalNotifications)
if err != nil {
s.logger.Error(ctx, "Failed to get total upcoming notifications", err, map[string]interface{}{})
return nil, nil, nil, contextutils.WrapError(err, "failed to get total upcoming notifications")
}
// Calculate pagination
offset := (page - 1) * pageSize
totalPages := (totalNotifications + pageSize - 1) / pageSize
// Get notifications with pagination
args = append(args, pageSize, offset)
query := fmt.Sprintf(`
SELECT un.id, un.user_id, u.username, u.email, un.notification_type,
un.scheduled_for, un.status, un.created_at
FROM upcoming_notifications un
LEFT JOIN users u ON un.user_id = u.id
%s
ORDER BY un.scheduled_for ASC
LIMIT $%d OFFSET $%d
`, whereClause, argIndex, argIndex+1)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
s.logger.Error(ctx, "Failed to get upcoming notifications", err, map[string]interface{}{})
return nil, nil, nil, contextutils.WrapError(err, "failed to get upcoming notifications")
}
defer func() {
if closeErr := rows.Close(); closeErr != nil {
s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
}
}()
var notifications []map[string]interface{}
for rows.Next() {
var notification map[string]interface{}
var id, userID int
var username, notificationType, status string
var scheduledFor, createdAt time.Time
var email sql.NullString
err := rows.Scan(&id, &userID, &username, &email, ¬ificationType, &scheduledFor, &status, &createdAt)
if err != nil {
s.logger.Error(ctx, "Failed to scan upcoming notification", err, map[string]interface{}{})
return nil, nil, nil, contextutils.WrapError(err, "failed to scan upcoming notification")
}
notification = map[string]interface{}{
"id": id,
"user_id": userID,
"username": username,
"notification_type": notificationType,
"scheduled_for": scheduledFor.Format(time.RFC3339),
"status": status,
"created_at": createdAt.Format(time.RFC3339),
}
if email.Valid {
notification["email_address"] = email.String
} else {
notification["email_address"] = ""
}
notifications = append(notifications, notification)
}
// Get stats
stats := map[string]interface{}{
"total_pending": 0,
"total_scheduled_today": 0,
"total_scheduled_this_week": 0,
}
// Get total pending
var totalPending int
err = s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM upcoming_notifications WHERE status = 'pending'").Scan(&totalPending)
if err != nil {
s.logger.Error(ctx, "Failed to get total pending", err, map[string]interface{}{})
} else {
stats["total_pending"] = totalPending
}
// Get scheduled today
var scheduledToday int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM upcoming_notifications
WHERE status = 'pending' AND DATE(scheduled_for) = CURRENT_DATE
`).Scan(&scheduledToday)
if err != nil {
s.logger.Error(ctx, "Failed to get scheduled today", err, map[string]interface{}{})
} else {
stats["total_scheduled_today"] = scheduledToday
}
// Get scheduled this week
var scheduledThisWeek int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM upcoming_notifications
WHERE status = 'pending' AND scheduled_for >= DATE_TRUNC('week', CURRENT_DATE)
`).Scan(&scheduledThisWeek)
if err != nil {
s.logger.Error(ctx, "Failed to get scheduled this week", err, map[string]interface{}{})
} else {
stats["total_scheduled_this_week"] = scheduledThisWeek
}
// Get notifications by type
notificationsByType := make(map[string]int)
rows, err = s.db.QueryContext(ctx, "SELECT notification_type, COUNT(*) FROM upcoming_notifications GROUP BY notification_type")
if err != nil {
s.logger.Error(ctx, "Failed to get notifications by type", err, map[string]interface{}{})
} else {
defer func() {
if closeErr := rows.Close(); closeErr != nil {
s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
}
}()
for rows.Next() {
var notificationType string
var count int
if err := rows.Scan(¬ificationType, &count); err != nil {
s.logger.Error(ctx, "Failed to scan notifications by type", err, map[string]interface{}{})
continue
}
notificationsByType[notificationType] = count
}
stats["notifications_by_type"] = notificationsByType
}
pagination := map[string]interface{}{
"page": page,
"page_size": pageSize,
"total": totalNotifications,
"total_pages": totalPages,
}
s.logger.Debug(ctx, "Retrieved upcoming notifications", map[string]interface{}{
"count": len(notifications), "page": page, "total": totalNotifications,
})
return notifications, pagination, stats, nil
}
// GetSentNotifications returns paginated sent notifications with filtering
func (s *WorkerService) GetSentNotifications(ctx context.Context, page, pageSize int, notificationType, status, sentAfter, sentBefore string) (result0 []map[string]interface{}, result1, result2 map[string]interface{}, err error) {
ctx, span := observability.TraceWorkerFunction(ctx, "get_sent_notifications",
attribute.Int("page", page),
attribute.Int("page_size", pageSize),
attribute.String("notification_type", notificationType),
attribute.String("status", status),
attribute.String("sent_after", sentAfter),
attribute.String("sent_before", sentBefore),
)
defer observability.FinishSpan(span, &err)
// Build WHERE clause
whereConditions := []string{}
args := []interface{}{}
argIndex := 1
if notificationType != "" {
whereConditions = append(whereConditions, fmt.Sprintf("notification_type = $%d", argIndex))
args = append(args, notificationType)
argIndex++
}
if status != "" {
whereConditions = append(whereConditions, fmt.Sprintf("status = $%d", argIndex))
args = append(args, status)
argIndex++
}
if sentAfter != "" {
whereConditions = append(whereConditions, fmt.Sprintf("sent_at >= $%d", argIndex))
args = append(args, sentAfter)
argIndex++
}
if sentBefore != "" {
whereConditions = append(whereConditions, fmt.Sprintf("sent_at <= $%d", argIndex))
args = append(args, sentBefore)
argIndex++
}
whereClause := ""
if len(whereConditions) > 0 {
whereClause = "WHERE " + strings.Join(whereConditions, " AND ")
}
// Get total count
var totalNotifications int
countQuery := fmt.Sprintf("SELECT COUNT(*) FROM sent_notifications %s", whereClause)
err = s.db.QueryRowContext(ctx, countQuery, args...).Scan(&totalNotifications)
if err != nil {
s.logger.Error(ctx, "Failed to get total sent notifications", err, map[string]interface{}{})
return nil, nil, nil, contextutils.WrapError(err, "failed to get total sent notifications")
}
// Calculate pagination
offset := (page - 1) * pageSize
totalPages := (totalNotifications + pageSize - 1) / pageSize
// Get notifications with pagination
args = append(args, pageSize, offset)
query := fmt.Sprintf(`
SELECT sn.id, sn.user_id, u.username, u.email, sn.notification_type,
sn.subject, sn.template_name, sn.sent_at, sn.status, sn.error_message, sn.retry_count
FROM sent_notifications sn
LEFT JOIN users u ON sn.user_id = u.id
%s
ORDER BY sn.sent_at DESC
LIMIT $%d OFFSET $%d
`, whereClause, argIndex, argIndex+1)
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
s.logger.Error(ctx, "Failed to get sent notifications", err, map[string]interface{}{})
return nil, nil, nil, contextutils.WrapError(err, "failed to get sent notifications")
}
defer func() {
if closeErr := rows.Close(); closeErr != nil {
s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
}
}()
var notifications []map[string]interface{}
for rows.Next() {
var notification map[string]interface{}
var id, userID int
var username, notificationType, subject, templateName, status string
var sentAt time.Time
var errorMessage sql.NullString
var retryCount int
var email sql.NullString
err := rows.Scan(&id, &userID, &username, &email, ¬ificationType, &subject, &templateName, &sentAt, &status, &errorMessage, &retryCount)
if err != nil {
s.logger.Error(ctx, "Failed to scan sent notification", err, map[string]interface{}{})
return nil, nil, nil, contextutils.WrapError(err, "failed to scan sent notification")
}
notification = map[string]interface{}{
"id": id,
"user_id": userID,
"username": username,
"notification_type": notificationType,
"subject": subject,
"template_name": templateName,
"sent_at": sentAt.Format(time.RFC3339),
"status": status,
"retry_count": retryCount,
}
if email.Valid {
notification["email_address"] = email.String
} else {
notification["email_address"] = ""
}
if errorMessage.Valid {
notification["error_message"] = errorMessage.String
}
notifications = append(notifications, notification)
}
// Get stats
stats := map[string]interface{}{
"total_sent": 0,
"total_failed": 0,
"success_rate": 0.0,
"sent_today": 0,
"sent_this_week": 0,
}
// Get total sent
var totalSent int
err = s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM sent_notifications WHERE status = 'sent'").Scan(&totalSent)
if err != nil {
s.logger.Error(ctx, "Failed to get total sent", err, map[string]interface{}{})
} else {
stats["total_sent"] = totalSent
}
// Get total failed
var totalFailed int
err = s.db.QueryRowContext(ctx, "SELECT COUNT(*) FROM sent_notifications WHERE status = 'failed'").Scan(&totalFailed)
if err != nil {
s.logger.Error(ctx, "Failed to get total failed", err, map[string]interface{}{})
} else {
stats["total_failed"] = totalFailed
}
// Calculate success rate
if totalSent+totalFailed > 0 {
stats["success_rate"] = float64(totalSent) / float64(totalSent+totalFailed)
}
// Get sent today
var sentToday int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM sent_notifications
WHERE status = 'sent' AND DATE(sent_at) = CURRENT_DATE
`).Scan(&sentToday)
if err != nil {
s.logger.Error(ctx, "Failed to get sent today", err, map[string]interface{}{})
} else {
stats["sent_today"] = sentToday
}
// Get sent this week
var sentThisWeek int
err = s.db.QueryRowContext(ctx, `
SELECT COUNT(*) FROM sent_notifications
WHERE status = 'sent' AND sent_at >= DATE_TRUNC('week', CURRENT_DATE)
`).Scan(&sentThisWeek)
if err != nil {
s.logger.Error(ctx, "Failed to get sent this week", err, map[string]interface{}{})
} else {
stats["sent_this_week"] = sentThisWeek
}
// Get notifications by type
notificationsByType := make(map[string]int)
rows, err = s.db.QueryContext(ctx, "SELECT notification_type, COUNT(*) FROM sent_notifications GROUP BY notification_type")
if err != nil {
s.logger.Error(ctx, "Failed to get notifications by type", err, map[string]interface{}{})
} else {
defer func() {
if closeErr := rows.Close(); closeErr != nil {
s.logger.Error(ctx, "Failed to close rows", closeErr, map[string]interface{}{})
}
}()
for rows.Next() {
var notificationType string
var count int
if err := rows.Scan(¬ificationType, &count); err != nil {
s.logger.Error(ctx, "Failed to scan notifications by type", err, map[string]interface{}{})
continue
}
notificationsByType[notificationType] = count
}
stats["notifications_by_type"] = notificationsByType
}
pagination := map[string]interface{}{
"page": page,
"page_size": pageSize,
"total": totalNotifications,
"total_pages": totalPages,
}
s.logger.Debug(ctx, "Retrieved sent notifications", map[string]interface{}{
"count": len(notifications), "page": page, "total": totalNotifications,
})
return notifications, pagination, stats, nil
}
// CreateTestSentNotification inserts one sent_notifications row for testing
// purposes, with sent_at set to the current time.
//
// errorMessage is written exactly as given, so an empty string is stored as ""
// rather than NULL. If the insert fails, the error is recorded on the span,
// logged, and returned wrapped. On success, an info log entry is emitted.
func (s *WorkerService) CreateTestSentNotification(ctx context.Context, userID int, notificationType, subject, templateName, status, errorMessage string) error {
	ctx, span := observability.TraceWorkerFunction(ctx, "create_test_sent_notification",
		attribute.Int("user.id", userID),
		attribute.String("notification.type", notificationType),
		attribute.String("notification.status", status),
	)
	defer span.End()

	const insertSentNotification = `
INSERT INTO sent_notifications (user_id, notification_type, subject, template_name, sent_at, status, error_message)
VALUES ($1, $2, $3, $4, $5, $6, $7)
`
	// Exactly one of the two log calls below runs, so one map serves both.
	logFields := map[string]interface{}{
		"user_id":           userID,
		"notification_type": notificationType,
		"status":            status,
	}

	if _, execErr := s.db.ExecContext(ctx, insertSentNotification, userID, notificationType, subject, templateName, time.Now(), status, errorMessage); execErr != nil {
		span.RecordError(execErr)
		s.logger.Error(ctx, "Failed to create test sent notification", execErr, logFields)
		return contextutils.WrapError(execErr, "failed to create test sent notification")
	}

	s.logger.Info(ctx, "Created test sent notification", logFields)
	return nil
}
// Package contextutils provides error handling utilities and standardized error types
// for consistent error management across the quiz application.
package contextutils
import (
"fmt"
"strings"
)
// ErrorCode represents a standardized error code for API responses.
// Its values are the string constants declared in the const block below
// (for example ErrorCodeRecordNotFound = "RECORD_NOT_FOUND").
type ErrorCode string
const (
// Database error codes
// ErrorCodeDatabaseConnection indicates a database connection error
ErrorCodeDatabaseConnection ErrorCode = "DATABASE_CONNECTION_ERROR"
// ErrorCodeDatabaseQuery indicates a database query error
ErrorCodeDatabaseQuery ErrorCode = "DATABASE_QUERY_ERROR"
// ErrorCodeDatabaseTransaction indicates a database transaction error
ErrorCodeDatabaseTransaction ErrorCode = "DATABASE_TRANSACTION_ERROR"
// ErrorCodeRecordNotFound indicates that a requested record was not found
ErrorCodeRecordNotFound ErrorCode = "RECORD_NOT_FOUND"
// ErrorCodeRecordExists indicates that a record already exists (duplicate key)
ErrorCodeRecordExists ErrorCode = "RECORD_ALREADY_EXISTS"
// ErrorCodeForeignKeyViolation indicates a foreign key constraint violation
ErrorCodeForeignKeyViolation ErrorCode = "FOREIGN_KEY_VIOLATION"
// Validation error codes
// ErrorCodeInvalidInput indicates that the provided input is invalid
ErrorCodeInvalidInput ErrorCode = "INVALID_INPUT"
// ErrorCodeMissingRequired indicates that a required field is missing
ErrorCodeMissingRequired ErrorCode = "MISSING_REQUIRED_FIELD"
// ErrorCodeInvalidFormat indicates that the input format is invalid
ErrorCodeInvalidFormat ErrorCode = "INVALID_FORMAT"
// ErrorCodeValidationFailed indicates that validation has failed
ErrorCodeValidationFailed ErrorCode = "VALIDATION_FAILED"
// Authentication error codes
// ErrorCodeUnauthorized indicates that the user is not authorized
ErrorCodeUnauthorized ErrorCode = "UNAUTHORIZED"
// ErrorCodeForbidden indicates that the user is forbidden from accessing the resource
ErrorCodeForbidden ErrorCode = "FORBIDDEN"
// ErrorCodeInvalidCredentials indicates that the provided credentials are invalid
ErrorCodeInvalidCredentials ErrorCode = "INVALID_CREDENTIALS"
// ErrorCodeSessionExpired indicates that the user session has expired
ErrorCodeSessionExpired ErrorCode = "SESSION_EXPIRED"
// Service error codes
// ErrorCodeServiceUnavailable indicates that the service is temporarily unavailable
ErrorCodeServiceUnavailable ErrorCode = "SERVICE_UNAVAILABLE"
// ErrorCodeTimeout indicates that a request has timed out
ErrorCodeTimeout ErrorCode = "REQUEST_TIMEOUT"
// ErrorCodeRateLimit indicates that the rate limit has been exceeded
ErrorCodeRateLimit ErrorCode = "RATE_LIMIT_EXCEEDED"
// ErrorCodeInternalError indicates an internal server error
ErrorCodeInternalError ErrorCode = "INTERNAL_SERVER_ERROR"
// ErrorCodeAssignmentNotFound indicates that a question assignment was not found
ErrorCodeAssignmentNotFound ErrorCode = "ASSIGNMENT_NOT_FOUND"
// Question error codes
// ErrorCodeTimestampMissingTimezone indicates that a timestamp is missing timezone information
ErrorCodeTimestampMissingTimezone ErrorCode = "TIMESTAMP_MISSING_TIMEZONE"
// ErrorCodeNoQuestionsAvailable indicates that no questions are available
ErrorCodeNoQuestionsAvailable ErrorCode = "NO_QUESTIONS_AVAILABLE"
// ErrorCodeQuestionAlreadyAnswered indicates that the question has already been answered
ErrorCodeQuestionAlreadyAnswered ErrorCode = "QUESTION_ALREADY_ANSWERED"
// ErrorCodeQuestionNotFound indicates that the requested question was not found
ErrorCodeQuestionNotFound ErrorCode = "QUESTION_NOT_FOUND"
// ErrorCodeInvalidAnswerIndex indicates that the answer index is invalid
ErrorCodeInvalidAnswerIndex ErrorCode = "INVALID_ANSWER_INDEX"
// AI Service error codes
// ErrorCodeAIProviderUnavailable indicates that the AI provider is unavailable
ErrorCodeAIProviderUnavailable ErrorCode = "AI_PROVIDER_UNAVAILABLE"
// ErrorCodeAIRequestFailed indicates that the AI request failed
ErrorCodeAIRequestFailed ErrorCode = "AI_REQUEST_FAILED"
// ErrorCodeAIResponseInvalid indicates that the AI response is invalid
ErrorCodeAIResponseInvalid ErrorCode = "AI_RESPONSE_INVALID"
// ErrorCodeAIConfigInvalid indicates that the AI configuration is invalid
ErrorCodeAIConfigInvalid ErrorCode = "AI_CONFIG_INVALID"
// OAuth error codes
// ErrorCodeOAuthCodeExpired indicates that the OAuth authorization code has expired
ErrorCodeOAuthCodeExpired ErrorCode = "OAUTH_CODE_EXPIRED"
// ErrorCodeOAuthStateMismatch indicates that the OAuth state parameter does not match
ErrorCodeOAuthStateMismatch ErrorCode = "OAUTH_STATE_MISMATCH"
// ErrorCodeOAuthProviderError indicates an error from the OAuth provider
ErrorCodeOAuthProviderError ErrorCode = "OAUTH_PROVIDER_ERROR"
)
// SeverityLevel represents the severity of an error for logging and monitoring.
//
// Besides being reported as the "severity" field of (*AppError).ToJSON, it
// changes behavior: ToJSON includes the "cause" text only for SeverityError
// and SeverityFatal, and IsRetryable never treats a SeverityFatal error as
// retryable.
type SeverityLevel string
// Severity levels, from least to most severe.
const (
// SeverityDebug indicates debug-level errors for development
SeverityDebug SeverityLevel = "debug"
// SeverityInfo indicates informational errors
SeverityInfo SeverityLevel = "info"
// SeverityWarn indicates warning-level errors
SeverityWarn SeverityLevel = "warn"
// SeverityError indicates error-level issues
SeverityError SeverityLevel = "error"
// SeverityFatal indicates fatal errors that require immediate attention
SeverityFatal SeverityLevel = "fatal"
)
// AppError is the application's structured error value. It carries a
// machine-readable Code, a Severity, a human-readable Message, optional
// free-form Details, and an optional wrapped Cause.
type AppError struct {
	Code     ErrorCode
	Severity SeverityLevel
	Message  string
	Details  string
	Cause    error
}

// Error renders the error as "CODE: message", with " - details" appended
// when Details is non-empty.
func (e *AppError) Error() string {
	text := fmt.Sprintf("%s: %s", e.Code, e.Message)
	if e.Details == "" {
		return text
	}
	return text + " - " + e.Details
}

// Unwrap returns Cause so errors.Is and errors.As can walk past this error.
func (e *AppError) Unwrap() error { return e.Cause }

// Is reports whether target is an *AppError with the same Code. Message,
// Details, Severity and Cause are ignored, so errors.Is matches any AppError
// that carries the sentinel's code.
func (e *AppError) Is(target error) bool {
	other, ok := target.(*AppError)
	return ok && other.Code == e.Code
}
// Error types for consistent error handling with associated codes and severity.
//
// Each value is a shared *AppError sentinel. Match against them with
// errors.Is or IsError, which compare by Code (see (*AppError).Is), rather
// than with ==: WrapError and friends return new *AppError values that keep
// the code but are different pointers. These are package-level pointers
// shared by every caller, so never mutate their fields; build a new error
// with NewAppError or WrapError instead.
var (
// Database errors
// ErrDatabaseConnection is the sentinel for ErrorCodeDatabaseConnection.
ErrDatabaseConnection = &AppError{
Code: ErrorCodeDatabaseConnection,
Severity: SeverityError,
Message: "Database connection failed",
}
// ErrDatabaseQuery is the sentinel for ErrorCodeDatabaseQuery.
ErrDatabaseQuery = &AppError{
Code: ErrorCodeDatabaseQuery,
Severity: SeverityError,
Message: "Database query failed",
}
// ErrDatabaseTransaction is the sentinel for ErrorCodeDatabaseTransaction.
ErrDatabaseTransaction = &AppError{
Code: ErrorCodeDatabaseTransaction,
Severity: SeverityError,
Message: "Database transaction failed",
}
// ErrRecordNotFound is the sentinel for ErrorCodeRecordNotFound (info severity).
ErrRecordNotFound = &AppError{
Code: ErrorCodeRecordNotFound,
Severity: SeverityInfo,
Message: "Record not found",
}
// ErrRecordExists is the sentinel for ErrorCodeRecordExists (info severity).
ErrRecordExists = &AppError{
Code: ErrorCodeRecordExists,
Severity: SeverityInfo,
Message: "Record already exists",
}
// ErrForeignKeyViolation is the sentinel for ErrorCodeForeignKeyViolation.
ErrForeignKeyViolation = &AppError{
Code: ErrorCodeForeignKeyViolation,
Severity: SeverityError,
Message: "Foreign key constraint violation",
}
// Validation errors
// ErrInvalidInput is the sentinel for ErrorCodeInvalidInput.
ErrInvalidInput = &AppError{
Code: ErrorCodeInvalidInput,
Severity: SeverityWarn,
Message: "Invalid input",
}
// ErrMissingRequired is the sentinel for ErrorCodeMissingRequired.
ErrMissingRequired = &AppError{
Code: ErrorCodeMissingRequired,
Severity: SeverityWarn,
Message: "Missing required field",
}
// ErrInvalidFormat is the sentinel for ErrorCodeInvalidFormat.
ErrInvalidFormat = &AppError{
Code: ErrorCodeInvalidFormat,
Severity: SeverityWarn,
Message: "Invalid format",
}
// ErrValidationFailed is the sentinel for ErrorCodeValidationFailed.
ErrValidationFailed = &AppError{
Code: ErrorCodeValidationFailed,
Severity: SeverityWarn,
Message: "Validation failed",
}
// Authentication errors
// ErrUnauthorized is the sentinel for ErrorCodeUnauthorized.
ErrUnauthorized = &AppError{
Code: ErrorCodeUnauthorized,
Severity: SeverityWarn,
Message: "Unauthorized",
}
// ErrForbidden is the sentinel for ErrorCodeForbidden.
ErrForbidden = &AppError{
Code: ErrorCodeForbidden,
Severity: SeverityWarn,
Message: "Forbidden",
}
// ErrInvalidCredentials is the sentinel for ErrorCodeInvalidCredentials.
ErrInvalidCredentials = &AppError{
Code: ErrorCodeInvalidCredentials,
Severity: SeverityWarn,
Message: "Invalid credentials",
}
// ErrSessionExpired is the sentinel for ErrorCodeSessionExpired (info severity).
ErrSessionExpired = &AppError{
Code: ErrorCodeSessionExpired,
Severity: SeverityInfo,
Message: "Session expired",
}
// Service errors
// ErrServiceUnavailable is the sentinel for ErrorCodeServiceUnavailable;
// IsRetryable treats this code as retryable.
ErrServiceUnavailable = &AppError{
Code: ErrorCodeServiceUnavailable,
Severity: SeverityError,
Message: "Service unavailable",
}
// ErrTimeout is the sentinel for ErrorCodeTimeout; IsRetryable treats this
// code as retryable.
ErrTimeout = &AppError{
Code: ErrorCodeTimeout,
Severity: SeverityWarn,
Message: "Request timeout",
}
// ErrRateLimit is the sentinel for ErrorCodeRateLimit.
ErrRateLimit = &AppError{
Code: ErrorCodeRateLimit,
Severity: SeverityWarn,
Message: "Rate limit exceeded",
}
// ErrInternalError is the sentinel for ErrorCodeInternalError, the code
// WrapError assigns to errors that are not AppErrors.
ErrInternalError = &AppError{
Code: ErrorCodeInternalError,
Severity: SeverityError,
Message: "Internal server error",
}
// ErrAssignmentNotFound is the sentinel for ErrorCodeAssignmentNotFound (info severity).
ErrAssignmentNotFound = &AppError{
Code: ErrorCodeAssignmentNotFound,
Severity: SeverityInfo,
Message: "Assignment not found",
}
// Question errors
// ErrTimestampMissingTimezone is the sentinel for
// ErrorCodeTimestampMissingTimezone; FormatTimeInUserTimezone returns it for
// midnight-UTC values when the user has a non-UTC timezone.
ErrTimestampMissingTimezone = &AppError{
Code: ErrorCodeTimestampMissingTimezone,
Severity: SeverityError,
Message: "Timestamp missing timezone",
}
// ErrNoQuestionsAvailable is the sentinel for ErrorCodeNoQuestionsAvailable (info severity).
ErrNoQuestionsAvailable = &AppError{
Code: ErrorCodeNoQuestionsAvailable,
Severity: SeverityInfo,
Message: "No questions available for assignment",
}
// ErrQuestionAlreadyAnswered is the sentinel for ErrorCodeQuestionAlreadyAnswered (info severity).
ErrQuestionAlreadyAnswered = &AppError{
Code: ErrorCodeQuestionAlreadyAnswered,
Severity: SeverityInfo,
Message: "Question already answered",
}
// ErrQuestionNotFound is the sentinel for ErrorCodeQuestionNotFound (info severity).
ErrQuestionNotFound = &AppError{
Code: ErrorCodeQuestionNotFound,
Severity: SeverityInfo,
Message: "Question not found",
}
// ErrInvalidAnswerIndex is the sentinel for ErrorCodeInvalidAnswerIndex.
ErrInvalidAnswerIndex = &AppError{
Code: ErrorCodeInvalidAnswerIndex,
Severity: SeverityWarn,
Message: "Invalid answer index",
}
// AI Service errors
// ErrAIProviderUnavailable is the sentinel for ErrorCodeAIProviderUnavailable.
ErrAIProviderUnavailable = &AppError{
Code: ErrorCodeAIProviderUnavailable,
Severity: SeverityError,
Message: "AI provider unavailable",
}
// ErrAIRequestFailed is the sentinel for ErrorCodeAIRequestFailed.
ErrAIRequestFailed = &AppError{
Code: ErrorCodeAIRequestFailed,
Severity: SeverityError,
Message: "AI request failed",
}
// ErrAIResponseInvalid is the sentinel for ErrorCodeAIResponseInvalid.
ErrAIResponseInvalid = &AppError{
Code: ErrorCodeAIResponseInvalid,
Severity: SeverityError,
Message: "AI response invalid",
}
// ErrAIConfigInvalid is the sentinel for ErrorCodeAIConfigInvalid.
ErrAIConfigInvalid = &AppError{
Code: ErrorCodeAIConfigInvalid,
Severity: SeverityError,
Message: "AI configuration invalid",
}
// OAuth errors
// ErrOAuthCodeExpired is the sentinel for ErrorCodeOAuthCodeExpired.
ErrOAuthCodeExpired = &AppError{
Code: ErrorCodeOAuthCodeExpired,
Severity: SeverityWarn,
Message: "OAuth code expired",
}
// ErrOAuthStateMismatch is the sentinel for ErrorCodeOAuthStateMismatch.
ErrOAuthStateMismatch = &AppError{
Code: ErrorCodeOAuthStateMismatch,
Severity: SeverityError,
Message: "OAuth state mismatch",
}
// ErrOAuthProviderError is the sentinel for ErrorCodeOAuthProviderError.
ErrOAuthProviderError = &AppError{
Code: ErrorCodeOAuthProviderError,
Severity: SeverityError,
Message: "OAuth provider error",
}
)
// NewAppError builds an *AppError from its parts, with no underlying cause.
func NewAppError(code ErrorCode, severity SeverityLevel, message, details string) *AppError {
	return NewAppErrorWithCause(code, severity, message, details, nil)
}

// NewAppErrorWithCause builds an *AppError from its parts and records cause
// as the wrapped error, which (*AppError).Unwrap returns.
func NewAppErrorWithCause(code ErrorCode, severity SeverityLevel, message, details string, cause error) *AppError {
	appErr := new(AppError)
	appErr.Code = code
	appErr.Severity = severity
	appErr.Message = message
	appErr.Details = details
	appErr.Cause = cause
	return appErr
}
// WrapError wraps an error with additional context, preserving AppError structure if possible
func WrapError(err error, context string) error {
if err == nil {
return nil
}
// If it's already an AppError, wrap it with additional details
if appErr, ok := err.(*AppError); ok {
return &AppError{
Code: appErr.Code,
Severity: appErr.Severity,
Message: context,
Details: appErr.Error(),
Cause: appErr,
}
}
// For regular errors, create a generic internal error wrapper
return &AppError{
Code: ErrorCodeInternalError,
Severity: SeverityError,
Message: context,
Details: err.Error(),
Cause: err,
}
}
// WrapErrorf wraps an error with formatted context, preserving AppError structure if possible
func WrapErrorf(err error, format string, args ...interface{}) error {
if err == nil {
return nil
}
// Handle %w verb for error wrapping by using fmt.Errorf
if strings.Contains(format, "%w") {
// Use fmt.Errorf to properly handle %w verb
wrappedErr := fmt.Errorf(format, args...)
// If it's already an AppError, wrap it with the formatted message
if appErr, ok := err.(*AppError); ok {
return &AppError{
Code: appErr.Code,
Severity: appErr.Severity,
Message: wrappedErr.Error(),
Details: appErr.Error(),
Cause: appErr,
}
}
// For regular errors, wrap with the formatted error
return &AppError{
Code: ErrorCodeInternalError,
Severity: SeverityError,
Message: wrappedErr.Error(),
Details: err.Error(),
Cause: err,
}
}
// If it's already an AppError, wrap it with additional details
if appErr, ok := err.(*AppError); ok {
context := fmt.Sprintf(format, args...)
return &AppError{
Code: appErr.Code,
Severity: appErr.Severity,
Message: context,
Details: appErr.Error(),
Cause: appErr,
}
}
// For regular errors, create a generic internal error wrapper
context := fmt.Sprintf(format, args...)
return &AppError{
Code: ErrorCodeInternalError,
Severity: SeverityError,
Message: context,
Details: err.Error(),
Cause: err,
}
}
// ErrorWithContextf creates a new *AppError with ErrorCodeInternalError and
// SeverityError whose Message is the formatted text.
//
// The format may use %w. It is rendered like %v, whereas fmt.Sprintf would
// emit "%!w(...)". When exactly one error is wrapped, that error becomes
// the AppError's Cause, so errors.Is and errors.As can reach it. Formats
// without %w produce the same message as fmt.Sprintf and leave Cause nil.
func ErrorWithContextf(format string, args ...interface{}) error {
	formatted := fmt.Errorf(format, args...)
	return &AppError{
		Code:     ErrorCodeInternalError,
		Severity: SeverityError,
		Message:  formatted.Error(),
		// nil unless the format wrapped exactly one error with %w.
		Cause: errors.Unwrap(formatted),
	}
}
// IsError reports whether err, or any error it wraps, carries the same Code
// as target. It delegates to errors.Is, which consults (*AppError).Is for
// code-based matching. An AppError behind fmt.Errorf("...: %w") or another
// wrapper is therefore still recognized. A nil target never matches.
func IsError(err error, target *AppError) bool {
	if target == nil {
		return false
	}
	return errors.Is(err, target)
}

// AsError finds the first *AppError in err's wrap chain, stores it in
// *target and returns true. When none is found, *target is left untouched
// and false is returned. A nil target always yields false.
func AsError(err error, target **AppError) bool {
	if target == nil {
		return false
	}
	return errors.As(err, target)
}

// GetErrorCode returns the Code of the first *AppError in err's wrap chain.
// It returns ErrorCodeInternalError when there is none, including when err
// is nil.
func GetErrorCode(err error) ErrorCode {
	var appErr *AppError
	if errors.As(err, &appErr) {
		return appErr.Code
	}
	return ErrorCodeInternalError
}

// GetErrorSeverity returns the Severity of the first *AppError in err's wrap
// chain. It returns SeverityError when there is none, including when err is
// nil.
func GetErrorSeverity(err error) SeverityLevel {
	var appErr *AppError
	if errors.As(err, &appErr) {
		return appErr.Severity
	}
	return SeverityError
}

// IsRetryable reports whether err is likely transient. The first *AppError
// in its wrap chain must have code REQUEST_TIMEOUT, SERVICE_UNAVAILABLE or
// DATABASE_CONNECTION_ERROR, and a severity other than SeverityFatal.
// Everything else, including errors that are not AppErrors, is not
// retryable.
func IsRetryable(err error) bool {
	var appErr *AppError
	if !errors.As(err, &appErr) {
		return false
	}
	switch appErr.Code {
	case ErrorCodeTimeout, ErrorCodeServiceUnavailable, ErrorCodeDatabaseConnection:
		return appErr.Severity != SeverityFatal
	default:
		return false
	}
}
// GetErrorLocalizedMessage returns the catalog message for err's code in the
// given locale (a tag such as "fr-CA", reduced by ParseLocale), followed by
// ": <Details>" when the AppError has details. The first *AppError in err's
// wrap chain is used. Errors without one, including nil, yield the generic
// "An error occurred".
func GetErrorLocalizedMessage(err error, locale string) string {
	var appErr *AppError
	if !errors.As(err, &appErr) {
		return "An error occurred"
	}
	return GetLocalizedMessageWithDetails(appErr.Code, ParseLocale(locale), appErr.Details)
}
// ToJSON returns a map ready to be JSON-encoded in an API response.
//
// Keys that are always present:
//   - "code": the error code as a string
//   - "message": Message
//   - "severity": Severity as a string
//   - "error": a duplicate of "message", kept for backward compatibility
//   - "retryable": IsRetryable(e)
//
// Optional keys: "details" when Details is non-empty, and "cause" (the
// Cause's Error() text) when a Cause exists and Severity is SeverityError or
// SeverityFatal. NOTE(review): "cause" may expose internal error text, such
// as database messages, to API clients.
func (e *AppError) ToJSON() map[string]interface{} {
	payload := make(map[string]interface{}, 7)
	payload["code"] = string(e.Code)
	payload["message"] = e.Message
	payload["severity"] = string(e.Severity)
	payload["error"] = e.Message

	if len(e.Details) > 0 {
		payload["details"] = e.Details
	}
	payload["retryable"] = IsRetryable(e)

	exposeCause := e.Severity == SeverityError || e.Severity == SeverityFatal
	if e.Cause != nil && exposeCause {
		payload["cause"] = e.Cause.Error()
	}
	return payload
}
// ToJSONWithLocale returns the same map as ToJSON, except that "message" and
// "error" both hold the catalog text for e.Code in the given locale. The
// locale string is reduced by ParseLocale first.
func (e *AppError) ToJSONWithLocale(locale string) map[string]interface{} {
	localized := GetLocalizedMessage(e.Code, ParseLocale(locale))
	payload := e.ToJSON()
	// Both keys carry the message; keep them identical.
	for _, key := range [...]string{"message", "error"} {
		payload[key] = localized
	}
	return payload
}
package contextutils
import (
"encoding/json"
"fmt"
"strings"
)
// Locale represents a language locale (e.g., "en", "es", "fr").
//
// ParseLocale reduces tags such as "en-US" to their language part, so
// messages are normally registered under bare language codes.
type Locale string
// Locales with built-in constants. Other Locale values are accepted; a
// lookup for a locale without a registered message falls back to English.
const (
// LocaleEnglish represents English language
LocaleEnglish Locale = "en"
// LocaleSpanish represents Spanish language
LocaleSpanish Locale = "es"
// LocaleFrench represents French language
LocaleFrench Locale = "fr"
// LocaleGerman represents German language
LocaleGerman Locale = "de"
// LocaleItalian represents Italian language
LocaleItalian Locale = "it"
)
// LocalizedMessages is a catalog of error messages keyed by error code and
// locale. The zero value is an empty, usable catalog. It uses plain maps
// without locking, so writes (AddMessage, LoadMessagesFromJSON) must not
// run concurrently with other calls.
type LocalizedMessages struct {
	messages map[ErrorCode]map[Locale]string
}

// NewLocalizedMessages creates a new, empty message catalog.
func NewLocalizedMessages() *LocalizedMessages {
	return &LocalizedMessages{
		messages: make(map[ErrorCode]map[Locale]string),
	}
}

// AddMessage registers message for code in locale, replacing any previous
// entry for the same pair.
func (lm *LocalizedMessages) AddMessage(code ErrorCode, locale Locale, message string) {
	if lm.messages == nil {
		// Allocate lazily so a zero-value LocalizedMessages{} does not panic
		// on its first write (assigning into a nil map panics).
		lm.messages = make(map[ErrorCode]map[Locale]string)
	}
	byLocale := lm.messages[code]
	if byLocale == nil {
		byLocale = make(map[Locale]string)
		lm.messages[code] = byLocale
	}
	byLocale[locale] = message
}

// GetMessage returns the message registered for code in locale. If that
// locale has no entry, it falls back to the English entry. If there is no
// English entry either, or lm is nil, it returns getDefaultMessage(code).
func (lm *LocalizedMessages) GetMessage(code ErrorCode, locale Locale) string {
	if lm != nil {
		// An unknown code yields a nil inner map, which is safe to read.
		byLocale := lm.messages[code]
		if message, ok := byLocale[locale]; ok {
			return message
		}
		if message, ok := byLocale[LocaleEnglish]; ok {
			return message
		}
	}
	return getDefaultMessage(code)
}

// GetMessageWithDetails returns GetMessage(code, locale), followed by
// ": details" when details is non-empty.
func (lm *LocalizedMessages) GetMessageWithDetails(code ErrorCode, locale Locale, details string) string {
	message := lm.GetMessage(code, locale)
	if details != "" {
		return fmt.Sprintf("%s: %s", message, details)
	}
	return message
}
// defaultErrorMessages holds the built-in English fallback text for each
// known ErrorCode. getDefaultMessage reads it when the catalog has neither a
// locale-specific nor an English entry.
var defaultErrorMessages = map[ErrorCode]string{
	ErrorCodeDatabaseConnection:       "Database connection failed",
	ErrorCodeDatabaseQuery:            "Database query failed",
	ErrorCodeDatabaseTransaction:      "Database transaction failed",
	ErrorCodeRecordNotFound:           "Record not found",
	ErrorCodeRecordExists:             "Record already exists",
	ErrorCodeForeignKeyViolation:      "Foreign key constraint violation",
	ErrorCodeInvalidInput:             "Invalid input",
	ErrorCodeMissingRequired:          "Missing required field",
	ErrorCodeInvalidFormat:            "Invalid format",
	ErrorCodeValidationFailed:         "Validation failed",
	ErrorCodeUnauthorized:             "Unauthorized access",
	ErrorCodeForbidden:                "Access forbidden",
	ErrorCodeInvalidCredentials:       "Invalid credentials",
	ErrorCodeSessionExpired:           "Session expired",
	ErrorCodeServiceUnavailable:       "Service temporarily unavailable",
	ErrorCodeTimeout:                  "Request timeout",
	ErrorCodeRateLimit:                "Rate limit exceeded",
	ErrorCodeInternalError:            "Internal server error",
	ErrorCodeAssignmentNotFound:       "Assignment not found",
	ErrorCodeTimestampMissingTimezone: "Timestamp missing timezone",
	ErrorCodeNoQuestionsAvailable:     "No questions available",
	ErrorCodeQuestionAlreadyAnswered:  "Question already answered",
	ErrorCodeQuestionNotFound:         "Question not found",
	ErrorCodeInvalidAnswerIndex:       "Invalid answer index",
	ErrorCodeAIProviderUnavailable:    "AI service unavailable",
	ErrorCodeAIRequestFailed:          "AI request failed",
	ErrorCodeAIResponseInvalid:        "AI response invalid",
	ErrorCodeAIConfigInvalid:          "AI configuration invalid",
	ErrorCodeOAuthCodeExpired:         "OAuth code expired",
	ErrorCodeOAuthStateMismatch:       "OAuth state mismatch",
	ErrorCodeOAuthProviderError:       "OAuth provider error",
}

// getDefaultMessage returns the built-in English message for code, or the
// generic "An error occurred" for codes without one.
func getDefaultMessage(code ErrorCode) string {
	if message, ok := defaultErrorMessages[code]; ok {
		return message
	}
	return "An error occurred"
}
// LoadMessagesFromJSON merges translations from a JSON object shaped like
// {"<ERROR_CODE>": {"<locale>": "<message>", ...}, ...} into lm. Entries
// for an existing code/locale pair are overwritten.
//
// If the JSON cannot be decoded, the decode error is returned wrapped with
// "failed to parse localization JSON" and nothing is added.
func (lm *LocalizedMessages) LoadMessagesFromJSON(jsonData string) error {
	var byCode map[string]map[string]string
	if err := json.Unmarshal([]byte(jsonData), &byCode); err != nil {
		return WrapError(err, "failed to parse localization JSON")
	}
	for rawCode, byLocale := range byCode {
		for rawLocale, text := range byLocale {
			lm.AddMessage(ErrorCode(rawCode), Locale(rawLocale), text)
		}
	}
	return nil
}
// GetSupportedLocales returns every locale that has at least one message
// registered for any error code, without duplicates. The order is
// unspecified (map iteration order). An empty catalog yields an empty,
// non-nil slice.
func (lm *LocalizedMessages) GetSupportedLocales() []Locale {
	seen := make(map[Locale]struct{})
	for _, byLocale := range lm.messages {
		for loc := range byLocale {
			seen[loc] = struct{}{}
		}
	}

	out := make([]Locale, 0, len(seen))
	for loc := range seen {
		out = append(out, loc)
	}
	return out
}
// ParseLocale extracts the lower-cased language part of a locale tag.
//
// Everything from the first separator onward is dropped. Separators are
// "-" (BCP 47, "en-US"), "_" (POSIX, "fr_CA"), and "," or ";" (so a raw
// Accept-Language value such as "de-DE,de;q=0.9" yields "de"). Surrounding
// whitespace is ignored. Empty input, or input with no language part,
// yields LocaleEnglish.
func ParseLocale(localeStr string) Locale {
	tag := strings.TrimSpace(localeStr)
	if i := strings.IndexAny(tag, "-_,;"); i >= 0 {
		tag = tag[:i]
	}
	tag = strings.TrimSpace(tag)
	if tag == "" {
		return LocaleEnglish // Default fallback
	}
	return Locale(strings.ToLower(tag))
}
// globalLocalizedMessages is the package-wide catalog consulted by
// GetLocalizedMessage and GetLocalizedMessageWithDetails.
var globalLocalizedMessages = NewLocalizedMessages()

// init seeds the global catalog with Spanish, French and German text for a
// few common error codes. Codes or locales not registered here fall back to
// an English entry, then to the built-in English defaults.
//
// The literals are written as proper UTF-8 (á, é, è, ü). They previously
// appeared as mis-encoded sequences such as "invÃlida".
func init() {
	globalLocalizedMessages.AddMessage(ErrorCodeInvalidInput, LocaleSpanish, "Entrada inválida")
	globalLocalizedMessages.AddMessage(ErrorCodeInvalidInput, LocaleFrench, "Entrée invalide")
	globalLocalizedMessages.AddMessage(ErrorCodeInvalidInput, LocaleGerman, "Ungültige Eingabe")
	globalLocalizedMessages.AddMessage(ErrorCodeRecordNotFound, LocaleSpanish, "Registro no encontrado")
	globalLocalizedMessages.AddMessage(ErrorCodeRecordNotFound, LocaleFrench, "Enregistrement non trouvé")
	globalLocalizedMessages.AddMessage(ErrorCodeRecordNotFound, LocaleGerman, "Datensatz nicht gefunden")
	globalLocalizedMessages.AddMessage(ErrorCodeUnauthorized, LocaleSpanish, "Acceso no autorizado")
	globalLocalizedMessages.AddMessage(ErrorCodeUnauthorized, LocaleFrench, "Accès non autorisé")
	globalLocalizedMessages.AddMessage(ErrorCodeUnauthorized, LocaleGerman, "Unbefugter Zugriff")
	globalLocalizedMessages.AddMessage(ErrorCodeInternalError, LocaleSpanish, "Error interno del servidor")
	globalLocalizedMessages.AddMessage(ErrorCodeInternalError, LocaleFrench, "Erreur interne du serveur")
	globalLocalizedMessages.AddMessage(ErrorCodeInternalError, LocaleGerman, "Interner Serverfehler")
}
// GetLocalizedMessage returns the message for code in locale from the global
// catalog, with the catalog's English and built-in fallbacks.
func GetLocalizedMessage(code ErrorCode, locale Locale) string {
	return globalLocalizedMessages.GetMessage(code, locale)
}

// GetLocalizedMessageWithDetails is GetLocalizedMessage followed by
// ": details" when details is non-empty.
func GetLocalizedMessageWithDetails(code ErrorCode, locale Locale, details string) string {
	return globalLocalizedMessages.GetMessageWithDetails(code, locale, details)
}

// SetGlobalLocalizedMessages replaces the package-wide catalog.
//
// Passing nil installs an empty catalog, so lookups fall back to the
// built-in English defaults. Previously a nil pointer was stored and every
// later lookup panicked. The swap is a plain assignment without
// synchronization; call it during startup, before concurrent lookups begin.
func SetGlobalLocalizedMessages(messages *LocalizedMessages) {
	if messages == nil {
		messages = NewLocalizedMessages()
	}
	globalLocalizedMessages = messages
}
package contextutils
import (
"strings"
)
// MaskAPIKey returns a log-safe rendering of apiKey:
//   - "" becomes "[EMPTY]"
//   - keys of 8 bytes or fewer are replaced entirely by '*'
//   - longer keys keep their first 4 and last 4 bytes, with '*' in between
//
// Lengths are counted in bytes, which matches ASCII API keys.
func MaskAPIKey(apiKey string) string {
	n := len(apiKey)
	switch {
	case n == 0:
		return "[EMPTY]"
	case n <= 8:
		return strings.Repeat("*", n)
	default:
		var b strings.Builder
		b.Grow(n)
		b.WriteString(apiKey[:4])
		b.WriteString(strings.Repeat("*", n-8))
		b.WriteString(apiKey[n-4:])
		return b.String()
	}
}
package contextutils
import (
"context"
"time"
"quizapp/internal/models"
)
// ParseDateInUserTimezone parses dateStr (layout YYYY-MM-DD) as midnight in
// the user's timezone. userLookup is injected to fetch the user, which keeps
// this package decoupled from the user service and easy to stub in tests.
//
// It returns the parsed time (in that location), the effective timezone
// name, and an error. Users without a timezone, or with one that
// time.LoadLocation rejects, get "UTC". An error from userLookup is
// returned unchanged with an empty timezone name. A malformed date is
// returned wrapped with the message "invalid date format", together with
// the effective timezone name.
func ParseDateInUserTimezone(
	ctx context.Context,
	userID int,
	dateStr string,
	userLookup func(context.Context, int) (*models.User, error),
) (time.Time, string, error) {
	user, lookupErr := userLookup(ctx, userID)
	if lookupErr != nil {
		return time.Time{}, "", lookupErr
	}

	zoneName := "UTC"
	if user != nil && user.Timezone.Valid && user.Timezone.String != "" {
		zoneName = user.Timezone.String
	}
	loc, locErr := time.LoadLocation(zoneName)
	if locErr != nil {
		// Unknown or malformed zone names degrade to UTC instead of failing.
		loc, zoneName = time.UTC, "UTC"
	}

	day, parseErr := time.ParseInLocation("2006-01-02", dateStr, loc)
	if parseErr == nil {
		return day, zoneName, nil
	}
	return time.Time{}, zoneName, WrapError(parseErr, "invalid date format")
}
// ConvertTimeToUserLocation returns t expressed in the user's timezone,
// together with the effective timezone name. Users without a timezone, or
// with one that time.LoadLocation rejects, get "UTC". An error from
// userLookup is returned unchanged with an empty timezone name.
func ConvertTimeToUserLocation(
	ctx context.Context,
	userID int,
	t time.Time,
	userLookup func(context.Context, int) (*models.User, error),
) (time.Time, string, error) {
	user, lookupErr := userLookup(ctx, userID)
	if lookupErr != nil {
		return time.Time{}, "", lookupErr
	}

	zoneName := "UTC"
	if user != nil && user.Timezone.Valid && user.Timezone.String != "" {
		zoneName = user.Timezone.String
	}
	if loc, locErr := time.LoadLocation(zoneName); locErr == nil {
		return t.In(loc), zoneName, nil
	}
	// Fall back to UTC when the configured zone cannot be loaded.
	return t.In(time.UTC), "UTC", nil
}
// FormatTimeInUserTimezone formats t in the user's timezone using layout.
// It returns the formatted string and the effective timezone name, which is
// "UTC" when the user has no usable timezone.
//
// A value of exactly 00:00:00.000000000 in UTC may be a date-only timestamp
// that lost its zone. If the user has a configured timezone other than
// "UTC", ErrTimestampMissingTimezone is returned rather than silently
// shifting the date.
//
// userLookup is called at most once. Its error is returned unchanged with
// empty results. A nil userLookup is treated as "no user timezone" and
// formats in UTC; previously that case panicked.
func FormatTimeInUserTimezone(
	ctx context.Context,
	userID int,
	t time.Time,
	layout string,
	userLookup func(context.Context, int) (*models.User, error),
) (string, string, error) {
	var user *models.User
	if userLookup != nil {
		u, err := userLookup(ctx, userID)
		if err != nil {
			return "", "", err
		}
		user = u
	}
	hasZone := user != nil && user.Timezone.Valid && user.Timezone.String != ""

	isUTCMidnight := t.Location() == time.UTC &&
		t.Hour() == 0 && t.Minute() == 0 && t.Second() == 0 && t.Nanosecond() == 0
	if isUTCMidnight && hasZone && user.Timezone.String != "UTC" {
		return "", "", ErrTimestampMissingTimezone
	}

	timezone := "UTC"
	if hasZone {
		timezone = user.Timezone.String
	}
	loc, err := time.LoadLocation(timezone)
	if err != nil {
		// Fall back to UTC if the configured zone cannot be loaded.
		loc = time.UTC
		timezone = "UTC"
	}
	return t.In(loc).Format(layout), timezone, nil
}
// UserLocalDayRange returns the UTC bounds covering the last `days` calendar
// days in the user's timezone, as the half-open range [startUTC, endUTC).
// startUTC is local midnight at the start of the earliest day, and endUTC is
// local midnight at the start of tomorrow, both converted to UTC. It also
// returns the effective timezone name, which is "UTC" when the user has none
// or it cannot be loaded. days <= 0 is treated as 1 (just today). An error
// from userLookup is returned unchanged.
//
// Day boundaries use calendar arithmetic (AddDate), not fixed 24-hour steps.
// This keeps the range aligned to local midnight on daylight-saving
// transition days, when a local day is shorter or longer than 24 hours.
func UserLocalDayRange(ctx context.Context, userID, days int, userLookup func(context.Context, int) (*models.User, error)) (time.Time, time.Time, string, error) {
	if days <= 0 {
		days = 1
	}
	user, err := userLookup(ctx, userID)
	if err != nil {
		return time.Time{}, time.Time{}, "", err
	}
	timezone := "UTC"
	if user != nil && user.Timezone.Valid && user.Timezone.String != "" {
		timezone = user.Timezone.String
	}
	loc, err := time.LoadLocation(timezone)
	if err != nil {
		loc = time.UTC
		timezone = "UTC"
	}

	now := time.Now().In(loc)
	today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, loc)
	startLocal := today.AddDate(0, 0, -(days - 1))
	// Start of tomorrow. today.Add(24*time.Hour) would land at 01:00 or 23:00
	// on a DST transition day.
	endLocal := today.AddDate(0, 0, 1)
	return startLocal.UTC(), endLocal.UTC(), timezone, nil
}
package contextutils
import (
"github.com/go-playground/validator/v10"
)
// validate is the package's shared go-playground validator instance, built
// once when the package is initialized.
var validate = validator.New()

// IsValidEmail reports whether email satisfies go-playground/validator's
// "email" rule.
func IsValidEmail(email string) bool {
	if err := validate.Var(email, "email"); err != nil {
		return false
	}
	return true
}
// Package worker contains the background worker responsible for generating
// and maintaining daily question assignments, scheduling generation jobs,
// and reporting worker health. The worker runs independently of HTTP
// request handling and interacts with the database, AI providers, and
// other internal services to keep question queues primed for users.
package worker
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"math"
"os"
"strconv"
"strings"
"sync"
"time"
"quizapp/internal/config"
"quizapp/internal/models"
"quizapp/internal/observability"
"quizapp/internal/services"
"quizapp/internal/services/mailer"
contextutils "quizapp/internal/utils"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)
const (
// NoActionPrefix is used to identify actions that should not be processed.
// It mirrors config.NoActionPrefix so worker code can refer to it directly.
NoActionPrefix = config.NoActionPrefix
// triggerThrottleWindow mirrors config.WorkerTriggerThrottle.
triggerThrottleWindow = config.WorkerTriggerThrottle // Prevent multiple triggers for same user within this window
)
// Status represents the current state of the worker. It is serialized to
// JSON using the snake_case tags below; CurrentActivity and LastRunError are
// omitted when empty.
type Status struct {
IsRunning bool `json:"is_running"`
IsPaused bool `json:"is_paused"`
CurrentActivity string `json:"current_activity,omitempty"`
LastRunStart time.Time `json:"last_run_start"`
LastRunFinish time.Time `json:"last_run_finish"`
LastRunError string `json:"last_run_error,omitempty"`
NextRun time.Time `json:"next_run"`
}
// RunRecord tracks individual worker runs.
type RunRecord struct {
StartTime time.Time `json:"start_time"`
EndTime time.Time `json:"end_time"`
// Duration is a time.Duration, which encoding/json writes as an integer
// number of nanoseconds.
Duration time.Duration `json:"duration"`
Status string `json:"status"` // Success, Failure
Details string `json:"details"`
}
// ActivityLog represents a single activity log entry.
type ActivityLog struct {
Timestamp time.Time `json:"timestamp"`
Level string `json:"level"` // INFO, WARN, ERROR
Message string `json:"message"`
// UserID and Username are pointers so they are omitted from JSON when nil,
// presumably for entries that are not about a specific user.
UserID *int `json:"user_id,omitempty"`
Username *string `json:"username,omitempty"`
}
// UserFailureInfo tracks failure information for exponential backoff.
// NOTE(review): NextRetryTime presumably holds the earliest time the user
// should be retried; confirm against the backoff logic.
type UserFailureInfo struct {
ConsecutiveFailures int
LastFailureTime time.Time
NextRetryTime time.Time
}
// Config holds worker-specific configuration.
// NOTE(review): StartWorkerPaused presumably starts the worker in the paused
// state, and DailyHorizonDays presumably sets how many days ahead daily
// assignments are prepared; confirm with the worker's startup code.
type Config struct {
StartWorkerPaused bool
DailyHorizonDays int
}
// Worker manages AI question generation in the background.
//
// It holds the services it calls into, its own run status, history and
// recent activity, and per-user failure bookkeeping for exponential backoff.
// NOTE(review): which fields mu guards is not visible here; presumably
// status, history and activityLogs. Confirm before adding new fields.
type Worker struct {
userService services.UserServiceInterface
questionService services.QuestionServiceInterface
aiService services.AIServiceInterface
learningService services.LearningServiceInterface
workerService services.WorkerServiceInterface
dailyQuestionService services.DailyQuestionServiceInterface
emailService mailer.Mailer
hintService services.GenerationHintServiceInterface
// instance identifies this worker; it is reported as the
// "worker.instance" trace attribute.
instance string
status Status
history []RunRecord
activityLogs []ActivityLog // Circular buffer for recent activity logs
mu sync.RWMutex
// manualTrigger presumably requests an immediate run outside the normal
// schedule; confirm with its sender and receiver.
manualTrigger chan bool
// cfg is the application-wide configuration (e.g. email reminder settings).
cfg *config.Config
workerCfg Config
logger *observability.Logger
// Track failures for exponential backoff
userFailures map[int]*UserFailureInfo // userID -> failure info
failureMu sync.RWMutex // mutex for failure tracking
// Time function for testing - defaults to time.Now
timeNow func() time.Time
cancel context.CancelFunc // Added for cleanup
}
// checkForDailyReminders sends the daily reminder email to each eligible
// user when the current UTC hour equals cfg.Email.DailyReminder.Hour.
//
// It returns nil without doing anything when daily reminders are disabled
// or the hour does not match. Otherwise, for every user returned by
// getUsersNeedingDailyReminders it:
//   - attempts SendDailyReminder,
//   - records the attempt ("sent" or "failed") with RecordSentNotification,
//   - updates the user's last-reminder timestamp.
//
// Per-user failures are logged and counted but never abort the loop. Only a
// failure to load the user list is returned as an error.
func (w *Worker) checkForDailyReminders(ctx context.Context) error {
	ctx, span := otel.Tracer("worker").Start(ctx, "checkForDailyReminders",
		trace.WithAttributes(
			attribute.String("worker.instance", w.instance),
			attribute.Bool("email.daily_reminder.enabled", w.cfg.Email.DailyReminder.Enabled),
			attribute.Int("email.daily_reminder.hour", w.cfg.Email.DailyReminder.Hour),
			attribute.Bool("email.enabled", w.cfg.Email.Enabled),
		),
	)
	defer span.End()

	if !w.cfg.Email.DailyReminder.Enabled {
		w.logger.Info(ctx, "Daily reminders disabled, skipping", nil)
		return nil
	}

	// The reminder hour is matched against the UTC clock, not each user's
	// local time.
	currentHour := w.timeNow().UTC().Hour()
	reminderHour := w.cfg.Email.DailyReminder.Hour
	if currentHour != reminderHour {
		span.SetAttributes(
			attribute.Int("check.current_hour", currentHour),
			attribute.Int("check.reminder_hour", reminderHour),
			attribute.Bool("check.should_send", false),
			attribute.String("check.reason", "wrong_hour"),
		)
		return nil
	}
	span.SetAttributes(
		attribute.Int("check.current_hour", currentHour),
		attribute.Int("check.reminder_hour", reminderHour),
		attribute.Bool("check.should_send", true),
	)
	w.logger.Info(ctx, "Checking for users needing daily reminders", map[string]interface{}{
		"reminder_hour": reminderHour,
	})

	users, err := w.getUsersNeedingDailyReminders(ctx)
	if err != nil {
		span.RecordError(err)
		span.SetAttributes(
			attribute.Int("users.total", 0),
			attribute.Int("users.eligible", 0),
			attribute.Int("reminders.sent", 0),
		)
		w.logger.Error(ctx, "Failed to get users needing daily reminders", err, nil)
		return contextutils.WrapError(err, "failed to get users needing daily reminders")
	}
	span.SetAttributes(attribute.Int("users.total", len(users)))

	// Subject stored with each sent_notifications record.
	// NOTE(review): the character after "quiz!" appears mis-encoded (a lone
	// U+00F0); confirm the intended emoji.
	const subject = "Time for your daily quiz! ð"

	remindersSent, failedReminders := 0, 0
	for _, user := range users {
		status, errorMsg := "sent", ""
		if err := w.emailService.SendDailyReminder(ctx, &user); err != nil {
			failedReminders++
			status = "failed"
			errorMsg = err.Error()
			w.logger.Error(ctx, "Failed to send daily reminder", err, map[string]interface{}{
				"user_id": user.ID,
				"email":   user.Email.String,
			})
		} else {
			remindersSent++
		}

		// Record every attempt, successful or not.
		if err := w.emailService.RecordSentNotification(ctx, user.ID, "daily_reminder", subject, "daily_reminder", status, errorMsg); err != nil {
			w.logger.Error(ctx, "Failed to record sent notification", err, map[string]interface{}{
				"user_id": user.ID,
			})
		}

		// The last-reminder timestamp is updated after every attempt, including
		// failed sends. NOTE(review): if getUsersNeedingDailyReminders filters
		// on this timestamp, a failed send is not retried in this window;
		// confirm that is intended. A failure here is only logged and does not
		// change the sent/failed counters.
		if err := w.learningService.UpdateLastDailyReminderSent(ctx, user.ID); err != nil {
			w.logger.Error(ctx, "Failed to update last daily reminder sent timestamp", err, map[string]interface{}{
				"user_id": user.ID,
			})
		}
	}

	span.SetAttributes(
		attribute.Int("users.eligible", len(users)),
		attribute.Int("reminders.sent", remindersSent),
		attribute.Int("reminders.failed", failedReminders),
	)
	// Report a success rate only when there was something to send; with no
	// users the ratio would be 0/0 = NaN.
	if len(users) > 0 {
		span.SetAttributes(attribute.Float64("reminders.success_rate", float64(remindersSent)/float64(len(users))))
	}

	w.logger.Info(ctx, "Daily reminders processed", map[string]interface{}{
		"total_users":      len(users),
		"reminders_sent":   remindersSent,
		"reminders_failed": failedReminders,
		"reminder_hour":    reminderHour,
	})
	return nil
}
// getUsersNeedingDailyReminders returns the users who should receive today's
// daily reminder email: those with a non-empty email address, whose learning
// preferences have DailyReminderEnabled set, and who have not already been
// sent a reminder on the current UTC calendar day.
//
// Users whose learning preferences cannot be loaded are logged and skipped.
// The only error returned is a failure to list users.
func (w *Worker) getUsersNeedingDailyReminders(ctx context.Context) ([]models.User, error) {
	ctx, span := otel.Tracer("worker").Start(ctx, "getUsersNeedingDailyReminders")
	defer span.End()

	users, err := w.userService.GetAllUsers(ctx)
	if err != nil {
		span.RecordError(err)
		return nil, contextutils.WrapError(err, "failed to get users")
	}

	var eligibleUsers []models.User
	today := w.timeNow().UTC().Format("2006-01-02")
	for _, user := range users {
		// A reminder needs somewhere to go.
		if !user.Email.Valid || user.Email.String == "" {
			continue
		}

		prefs, err := w.learningService.GetUserLearningPreferences(ctx, user.ID)
		if err != nil {
			w.logger.Warn(ctx, "Failed to get user learning preferences for daily reminder check", map[string]interface{}{
				"user_id":  user.ID,
				"username": user.Username,
				"error":    err.Error(),
			})
			continue
		}
		if prefs == nil || !prefs.DailyReminderEnabled {
			continue
		}

		// Compare calendar days in UTC on both sides. "today" is a UTC date,
		// so the stored timestamp must be converted as well; otherwise a value
		// carrying a non-UTC location can format to a different date.
		if prefs.LastDailyReminderSent != nil &&
			prefs.LastDailyReminderSent.UTC().Format("2006-01-02") == today {
			continue
		}

		eligibleUsers = append(eligibleUsers, user)
	}

	w.logger.Info(ctx, "Found users eligible for daily reminders", map[string]interface{}{
		"total_users":    len(users),
		"eligible_users": len(eligibleUsers),
	})
	return eligibleUsers, nil
}
// checkForDailyQuestionAssignments assigns daily questions to every user that
// getUsersEligibleForDailyQuestions returns. It runs independently of email
// reminders, so users receive daily questions even when their reminders are
// disabled.
//
// For each user, "today" is midnight in the user's timezone (UTC when unset or
// invalid), and questions are assigned for every date in [today, today+horizon].
// The horizon is workerCfg.DailyHorizonDays (2 when not positive), raised to
// cfg.Server.DailyRepeatAvoidDays (7 when unset) if that is larger. Per-date
// failures are logged and counted; only a failure to load the user list is
// returned as an error.
func (w *Worker) checkForDailyQuestionAssignments(ctx context.Context) error {
	ctx, span := observability.TraceWorkerFunction(ctx, "check_for_daily_question_assignments",
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, nil)

	w.logger.Info(ctx, "Checking for daily question assignments", map[string]interface{}{
		"instance": w.instance,
	})

	users, err := w.getUsersEligibleForDailyQuestions(ctx)
	if err != nil {
		span.RecordError(err)
		w.logger.Error(ctx, "Failed to get users eligible for daily questions", err, nil)
		return contextutils.WrapError(err, "failed to get users eligible for daily questions")
	}
	if len(users) == 0 {
		w.logger.Info(ctx, "No users eligible for daily question assignments", map[string]interface{}{
			"instance": w.instance,
		})
		return nil
	}

	// The horizon does not depend on the user, so it is resolved (and any
	// extension logged) once per run instead of once per user.
	horizon := w.workerCfg.DailyHorizonDays
	if horizon <= 0 {
		// Default to 2 days ahead when misconfigured or not set.
		horizon = 2
	}
	// Cover the configured repeat-avoid window so that when future
	// assignments are removed (e.g. after a correct submission) a later run
	// tops up the missing slots. Server config is the source of truth.
	avoidDays := 7
	if w.cfg != nil && w.cfg.Server.DailyRepeatAvoidDays > 0 {
		avoidDays = w.cfg.Server.DailyRepeatAvoidDays
	}
	if horizon < avoidDays {
		w.logger.Info(ctx, "Extending worker daily horizon to cover daily repeat avoid window", map[string]interface{}{
			"old_horizon": horizon,
			"new_horizon": avoidDays,
		})
		horizon = avoidDays
	}

	span.SetAttributes(
		attribute.Int("users.total", len(users)),
		attribute.Int("assignments.horizon_days", horizon),
	)

	successfulAssignments := 0
	failedAssignments := 0
	for _, user := range users {
		timezone := "UTC"
		if user.Timezone.Valid && user.Timezone.String != "" {
			timezone = user.Timezone.String
		}
		loc, err := time.LoadLocation(timezone)
		if err != nil {
			w.logger.Warn(ctx, "Invalid timezone for user, using UTC", map[string]interface{}{
				"user_id":  user.ID,
				"username": user.Username,
				"timezone": timezone,
				"error":    err.Error(),
			})
			loc = time.UTC
		}

		// Midnight of the current date in the user's timezone.
		now := w.timeNow().In(loc)
		today := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, loc)

		for d := 0; d <= horizon; d++ {
			target := today.AddDate(0, 0, d)
			if err := w.dailyQuestionService.AssignDailyQuestions(ctx, user.ID, target); err != nil {
				failedAssignments++
				w.logger.Error(ctx, "Failed to assign daily questions", err, map[string]interface{}{
					"user_id":  user.ID,
					"username": user.Username,
					"timezone": timezone,
					"date":     target.Format("2006-01-02"),
				})
				continue
			}
			successfulAssignments++
			w.logger.Info(ctx, "Successfully assigned daily questions", map[string]interface{}{
				"user_id":  user.ID,
				"username": user.Username,
				"timezone": timezone,
				"date":     target.Format("2006-01-02"),
			})
		}
	}

	span.SetAttributes(
		attribute.Int("assignments.successful", successfulAssignments),
		attribute.Int("assignments.failed", failedAssignments),
	)
	w.logger.Info(ctx, "Completed daily question assignment check", map[string]interface{}{
		"instance":               w.instance,
		"eligible_users":         len(users),
		"successful_assignments": successfulAssignments,
		"failed_assignments":     failedAssignments,
	})
	return nil
}
// getUsersEligibleForDailyQuestions returns every user that has both a
// preferred language and a current level set (valid and non-empty). Email
// reminder preferences play no part in this selection.
//
// Users missing either field are skipped with a debug log. The only error
// returned is a failure to list users.
func (w *Worker) getUsersEligibleForDailyQuestions(ctx context.Context) ([]models.User, error) {
	ctx, span := otel.Tracer("worker").Start(ctx, "getUsersEligibleForDailyQuestions")
	defer span.End()

	all, err := w.userService.GetAllUsers(ctx)
	if err != nil {
		span.RecordError(err)
		return nil, contextutils.WrapError(err, "failed to get users")
	}

	var eligible []models.User
	for _, u := range all {
		// The first missing prerequisite decides the skip message.
		skipMsg := ""
		switch {
		case !u.PreferredLanguage.Valid || u.PreferredLanguage.String == "":
			skipMsg = "User missing preferred language, skipping daily question assignment"
		case !u.CurrentLevel.Valid || u.CurrentLevel.String == "":
			skipMsg = "User missing current level, skipping daily question assignment"
		}
		if skipMsg != "" {
			w.logger.Debug(ctx, skipMsg, map[string]interface{}{
				"user_id":  u.ID,
				"username": u.Username,
			})
			continue
		}
		eligible = append(eligible, u)
	}

	w.logger.Info(ctx, "Found users eligible for daily questions", map[string]interface{}{
		"total_users":    len(all),
		"eligible_users": len(eligible),
	})
	return eligible, nil
}
// NewWorker builds a Worker wired to the given services and configuration.
//
// An empty instance name becomes "default". The daily-question horizon comes
// from cfg.Server.DailyHorizonDays, or 1 when that is not positive. The run
// history and activity-log slices are pre-sized from cfg.Server.MaxHistory and
// cfg.Server.MaxActivityLogs. When the WORKER_START_PAUSED environment variable
// parses as true, the global pause flag is set during construction.
// cfg is dereferenced unconditionally and must not be nil.
func NewWorker(userService services.UserServiceInterface, questionService services.QuestionServiceInterface, aiService services.AIServiceInterface, learningService services.LearningServiceInterface, workerService services.WorkerServiceInterface, dailyQuestionService services.DailyQuestionServiceInterface, emailService mailer.Mailer, hintService services.GenerationHintServiceInterface, instance string, cfg *config.Config, logger *observability.Logger) *Worker {
	if len(instance) == 0 {
		instance = "default"
	}

	// A configured (positive) horizon wins; otherwise look one day ahead.
	horizon := 1
	if cfg.Server.DailyHorizonDays > 0 {
		horizon = cfg.Server.DailyHorizonDays
	}

	// The cancel function is kept on the worker so Shutdown can invoke it.
	baseCtx, cancel := context.WithCancel(context.Background())

	w := &Worker{
		userService:          userService,
		questionService:      questionService,
		aiService:            aiService,
		learningService:      learningService,
		workerService:        workerService,
		dailyQuestionService: dailyQuestionService,
		emailService:         emailService,
		hintService:          hintService,
		instance:             instance,
		status:               Status{IsRunning: false, CurrentActivity: "Initialized"},
		history:              make([]RunRecord, 0, cfg.Server.MaxHistory),
		activityLogs:         make([]ActivityLog, 0, cfg.Server.MaxActivityLogs),
		manualTrigger:        make(chan bool, 1),
		cfg:                  cfg,
		workerCfg: Config{
			StartWorkerPaused: getEnvBool("WORKER_START_PAUSED", false),
			DailyHorizonDays:  horizon,
		},
		logger:       logger,
		userFailures: make(map[int]*UserFailureInfo),
		timeNow:      time.Now, // real clock; tests may swap this out
		cancel:       cancel,
	}

	if w.workerCfg.StartWorkerPaused {
		w.handleStartupPause(baseCtx)
	}
	return w
}
// getEnvBool reads the environment variable key and parses it with
// strconv.ParseBool. It returns defaultValue when the variable is unset,
// set to the empty string, or not a recognised boolean literal.
func getEnvBool(key string, defaultValue bool) bool {
	if raw, present := os.LookupEnv(key); present && raw != "" {
		if parsed, err := strconv.ParseBool(raw); err == nil {
			return parsed
		}
	}
	return defaultValue
}
// Start marks the worker as running and processes work until ctx is cancelled.
//
// It persists the running status, applies the start-paused setting again, and
// launches heartbeatLoop. It then calls run on every
// config.WorkerHeartbeatInterval tick and on every manual trigger. When ctx is
// cancelled it marks the worker as stopped, persists that status, and returns.
// Start blocks until then.
func (w *Worker) Start(ctx context.Context) {
	// Upper bound on the final status write made after ctx has been cancelled.
	const shutdownStatusTimeout = 5 * time.Second

	// status is read concurrently by GetStatus under w.mu, so writes take it.
	w.mu.Lock()
	w.status.IsRunning = true
	w.mu.Unlock()
	w.updateDatabaseStatus(ctx)
	w.handleStartupPause(ctx)

	// heartbeatLoop exits when ctx is cancelled.
	go w.heartbeatLoop(ctx)

	ticker := time.NewTicker(config.WorkerHeartbeatInterval)
	defer ticker.Stop()

	initialStatus := w.getInitialWorkerStatus(ctx)
	w.logger.Info(ctx, "Worker started", map[string]interface{}{
		"instance": w.instance,
		"status":   initialStatus,
	})
	w.logActivity(ctx, "INFO", fmt.Sprintf("Worker %s started (%s)", w.instance, initialStatus), nil, nil)

	for {
		select {
		case <-ctx.Done():
			w.logger.Info(ctx, "Worker shutting down", map[string]interface{}{
				"instance": w.instance,
			})
			// ctx is already cancelled here. Any database call made with it
			// fails immediately, so the "not running" status would never be
			// persisted. Use a fresh, bounded context for the final writes.
			stopCtx, cancel := context.WithTimeout(context.Background(), shutdownStatusTimeout)
			w.logActivity(stopCtx, "INFO", fmt.Sprintf("Worker %s shutting down", w.instance), nil, nil)
			w.mu.Lock()
			w.status.IsRunning = false
			w.mu.Unlock()
			w.updateDatabaseStatus(stopCtx)
			cancel()
			return
		case <-ticker.C:
			w.run()
		case <-w.manualTrigger:
			w.logger.Info(ctx, "Worker triggered manually", map[string]interface{}{
				"instance": w.instance,
			})
			w.logActivity(ctx, "INFO", fmt.Sprintf("Worker %s triggered manually", w.instance), nil, nil)
			w.run()
		}
	}
}
// handleStartupPause sets the global pause flag through workerService when the
// worker was configured to start paused (workerCfg.StartWorkerPaused). It does
// nothing otherwise. A failure to set the flag is logged, not returned.
func (w *Worker) handleStartupPause(ctx context.Context) {
	if !w.workerCfg.StartWorkerPaused {
		return
	}

	w.logger.Info(ctx, "Worker configured to start paused - setting global pause", map[string]interface{}{
		"instance": w.instance,
	})

	err := w.workerService.SetGlobalPause(ctx, true)
	if err != nil {
		w.logger.Error(ctx, "Failed to set global pause on startup", err, map[string]interface{}{
			"instance": w.instance,
		})
		return
	}

	w.logger.Info(ctx, "Global pause set on startup as configured", map[string]interface{}{
		"instance": w.instance,
	})
}
// getInitialWorkerStatus returns the status label used in the startup log.
//
// Possible values:
//   - "paused (globally)" when the global pause flag is set
//   - "paused (instance)" when this instance's stored status is paused
//   - "running" otherwise
//
// "running" is also returned when the global flag cannot be read (logged as
// an error) or when this instance has no readable status yet (logged at
// debug level).
func (w *Worker) getInitialWorkerStatus(ctx context.Context) string {
	globalPaused, err := w.workerService.IsGlobalPaused(ctx)
	switch {
	case err != nil:
		w.logger.Error(ctx, "Failed to check global pause status on startup", err, map[string]interface{}{
			"instance": w.instance,
		})
		return "running"
	case globalPaused:
		return "paused (globally)"
	}

	instanceStatus, statusErr := w.workerService.GetWorkerStatus(ctx, w.instance)
	if statusErr != nil {
		// A missing status is expected the first time a worker starts.
		w.logger.Debug(ctx, "Worker status not found on startup (expected for new worker)", map[string]interface{}{
			"instance": w.instance,
		})
		return "running"
	}
	if instanceStatus != nil && instanceStatus.IsPaused {
		return "paused (instance)"
	}
	return "running"
}
// heartbeatLoop calls updateHeartbeat once every config.WorkerHeartbeatInterval
// and returns when ctx is cancelled.
func (w *Worker) heartbeatLoop(ctx context.Context) {
	tick := time.NewTicker(config.WorkerHeartbeatInterval)
	defer tick.Stop()

	for {
		select {
		case <-tick.C:
			w.updateHeartbeat(ctx)
		case <-ctx.Done():
			return
		}
	}
}
// updateHeartbeat records a heartbeat for this worker instance through
// workerService.UpdateHeartbeat. A failure is logged and otherwise ignored.
func (w *Worker) updateHeartbeat(ctx context.Context) {
	err := w.workerService.UpdateHeartbeat(ctx, w.instance)
	if err == nil {
		return
	}
	w.logger.Error(ctx, "Failed to update heartbeat for worker", err, map[string]interface{}{
		"instance": w.instance,
	})
}
// run executes a single worker cycle in a fresh context. The context is
// derived from context.Background(), not from Start's ctx.
//
// It first refreshes the persisted status. When checkPauseStatus reports the
// worker paused, it passes the reason to updateActivity and returns.
// Otherwise it:
//   - generates needed questions,
//   - assigns daily questions,
//   - sends daily reminder emails,
//   - then records the run's timing and outcome in the status and the run
//     history, and persists the status.
//
// Only the question-generation error marks the run as failed; errors from the
// daily assignment and reminder steps are logged.
func (w *Worker) run() {
	ctx, span := observability.TraceWorkerFunction(context.Background(), "run",
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, nil)

	// Ensure worker status is up to date before checking pause status.
	w.updateDatabaseStatus(ctx)
	if paused, reason := w.checkPauseStatus(ctx); paused {
		span.SetAttributes(attribute.String("pause_reason", reason))
		w.updateActivity(reason)
		return
	}

	// status is read concurrently by GetStatus under w.mu, so writes take it.
	w.mu.Lock()
	w.status.LastRunStart = time.Now()
	w.mu.Unlock()
	w.updateDatabaseStatus(ctx)

	details, err := w.generateNeededQuestions(ctx)

	// Assign daily questions to all eligible users (independent of email reminders).
	if assignErr := w.checkForDailyQuestionAssignments(ctx); assignErr != nil {
		w.logger.Error(ctx, "Failed to check daily question assignments", assignErr, map[string]interface{}{
			"instance": w.instance,
		})
	}
	if reminderErr := w.checkForDailyReminders(ctx); reminderErr != nil {
		w.logger.Error(ctx, "Failed to check daily reminders", reminderErr, map[string]interface{}{
			"instance": w.instance,
		})
	}

	w.mu.Lock()
	w.status.LastRunFinish = time.Now()
	if err != nil {
		w.status.LastRunError = err.Error()
	} else {
		w.status.LastRunError = ""
	}
	w.mu.Unlock()

	if err != nil {
		w.logger.Error(ctx, "Worker run failed", err, map[string]interface{}{
			"instance": w.instance,
		})
	}

	w.recordRunHistory(details, err)
	w.updateDatabaseStatus(ctx)
}
// checkPauseStatus reports whether the worker should skip this cycle and, if
// so, why.
//
// The global pause flag is consulted first; an error reading it is treated as
// paused. This instance's stored status is consulted next; an error reading it
// is logged at debug level and treated as not paused.
func (w *Worker) checkPauseStatus(ctx context.Context) (bool, string) {
	switch globalPaused, err := w.workerService.IsGlobalPaused(ctx); {
	case err != nil:
		w.logger.Error(ctx, "Failed to check global pause status", err, map[string]interface{}{
			"instance": w.instance,
		})
		return true, "Error checking global pause status"
	case globalPaused:
		return true, "Globally paused"
	}

	instanceStatus, err := w.workerService.GetWorkerStatus(ctx, w.instance)
	if err != nil {
		// A missing status can occur during startup; assume not paused.
		w.logger.Debug(ctx, "Worker status not found during pause check (assuming not paused)", map[string]interface{}{
			"instance": w.instance,
		})
		return false, ""
	}
	if instanceStatus != nil && instanceStatus.IsPaused {
		return true, "Worker instance paused"
	}
	return false, ""
}
// recordRunHistory appends a RunRecord for the run just finished and keeps
// only the newest cfg.Server.MaxHistory entries.
//
// The record's times come from status.LastRunStart and status.LastRunFinish.
// Its Status is "Failure" when err is non-nil and "Success" otherwise. The
// history slice is modified under w.mu.
func (w *Worker) recordRunHistory(details string, err error) {
	outcome := "Success"
	if err != nil {
		outcome = "Failure"
	}

	started, finished := w.status.LastRunStart, w.status.LastRunFinish
	entry := RunRecord{
		StartTime: started,
		EndTime:   finished,
		Duration:  finished.Sub(started),
		Details:   details,
		Status:    outcome,
	}

	w.mu.Lock()
	defer w.mu.Unlock()
	w.history = append(w.history, entry)
	if limit := w.cfg.Server.MaxHistory; len(w.history) > limit {
		// Drop the oldest entries beyond the limit.
		w.history = w.history[len(w.history)-limit:]
	}
}
// GetStatus returns a copy of the worker's current status, taken while holding
// the worker's read lock.
func (w *Worker) GetStatus() Status {
	w.mu.RLock()
	snapshot := w.status
	w.mu.RUnlock()
	return snapshot
}
// GetHistory returns a copy of the recorded run history in the order it was
// recorded, taken while holding the worker's read lock. The returned slice is
// never nil and the caller may modify it freely.
func (w *Worker) GetHistory() []RunRecord {
	w.mu.RLock()
	defer w.mu.RUnlock()
	return append(make([]RunRecord, 0, len(w.history)), w.history...)
}
// GetActivityLogs returns a copy of the in-memory activity log, taken while
// holding the worker's read lock. The returned slice is never nil and the
// caller may modify it freely.
func (w *Worker) GetActivityLogs() []ActivityLog {
	w.mu.RLock()
	defer w.mu.RUnlock()
	return append(make([]ActivityLog, 0, len(w.activityLogs)), w.activityLogs...)
}
// GetInstance returns the name this worker was created with ("default" when
// NewWorker received an empty name).
func (w *Worker) GetInstance() string {
	return w.instance
}
// GetEmailService returns the mailer the worker uses for outgoing email such
// as daily reminders.
func (w *Worker) GetEmailService() mailer.Mailer {
	return w.emailService
}
// TriggerManualRun asks the worker loop to run a cycle as soon as possible.
//
// The request is a non-blocking send on the manualTrigger channel. If a
// trigger is already queued, the new request is dropped, so repeated calls
// coalesce into one pending run. Either outcome is logged.
func (w *Worker) TriggerManualRun() {
	ctx := context.Background()
	message := "Manual trigger sent to worker"
	select {
	case w.manualTrigger <- true:
	default:
		message = "Manual trigger already pending for worker"
	}
	w.logger.Info(ctx, message, map[string]interface{}{
		"instance": w.instance,
	})
}
// Pause pauses this worker instance.
//
// It asks workerService to pause the instance, logs the pause, marks the local
// status as paused, and persists it. The local status is marked paused even if
// the service call fails; that failure is logged as a warning.
func (w *Worker) Pause(ctx context.Context) {
	if err := w.workerService.PauseWorker(ctx, w.instance); err != nil {
		w.logger.Warn(ctx, "Failed to pause worker in service", map[string]interface{}{
			"instance": w.instance,
			"error":    err.Error(),
		})
	}
	w.logger.Info(ctx, "Worker paused", map[string]interface{}{
		"instance": w.instance,
	})
	w.logActivity(ctx, "INFO", fmt.Sprintf("Worker %s paused", w.instance), nil, nil)

	// status is read concurrently by GetStatus under w.mu, so the write takes it.
	w.mu.Lock()
	w.status.IsPaused = true
	w.mu.Unlock()
	w.updateDatabaseStatus(ctx)
}
// Resume resumes this worker instance.
//
// It asks workerService to resume the instance. If that fails, the failure is
// logged, the local status is left unchanged, and the current status is
// persisted. On success it logs the resume, clears the local paused flag, and
// persists the status.
func (w *Worker) Resume(ctx context.Context) {
	if err := w.workerService.ResumeWorker(ctx, w.instance); err != nil {
		w.logger.Warn(ctx, "Failed to resume worker in service", map[string]interface{}{
			"instance": w.instance,
			"error":    err.Error(),
		})
		// Do not unpause locally if the service-side resume failed.
		w.updateDatabaseStatus(ctx)
		return
	}
	w.logger.Info(ctx, "Worker resumed", map[string]interface{}{
		"instance": w.instance,
	})
	w.logActivity(ctx, "INFO", fmt.Sprintf("Worker %s resumed", w.instance), nil, nil)

	// status is read concurrently by GetStatus under w.mu, so the write takes it.
	w.mu.Lock()
	w.status.IsPaused = false
	w.mu.Unlock()
	w.updateDatabaseStatus(ctx)
}
// Shutdown signals the worker's internal cancel function and waits up to
// config.WorkerSleepDuration for in-flight work to settle. The wait ends
// early if ctx is done. It then clears the per-user failure tracking and the
// in-memory activity log. It always returns nil.
func (w *Worker) Shutdown(ctx context.Context) error {
	w.logger.Info(ctx, "Worker starting shutdown", map[string]interface{}{
		"instance": w.instance,
	})

	w.mu.Lock()
	cancel := w.cancel
	w.mu.Unlock()
	if cancel != nil {
		cancel()
	}

	// Wait WITHOUT holding w.mu. Operations still in flight (for example
	// recordRunHistory) need the mutex in order to finish, and readers such
	// as GetStatus would otherwise block for the whole grace period.
	// NOTE(review): this is a fixed grace period; active operations are not
	// tracked individually.
	grace := time.NewTimer(config.WorkerSleepDuration)
	select {
	case <-grace.C:
	case <-ctx.Done():
		grace.Stop()
	}

	w.failureMu.Lock()
	w.userFailures = make(map[int]*UserFailureInfo)
	w.failureMu.Unlock()

	w.mu.Lock()
	w.activityLogs = make([]ActivityLog, 0)
	w.mu.Unlock()

	w.logger.Info(ctx, "Worker shutdown completed", map[string]interface{}{
		"instance": w.instance,
	})
	return nil
}
// updateDatabaseStatus persists a snapshot of the in-memory worker status for
// this instance through workerService.UpdateWorkerStatus.
//
// The heartbeat is stamped with the current wall-clock time. Empty strings and
// zero times are sent with Valid=false (NULL). TotalRuns is the current
// history length. Errors are logged, not returned.
func (w *Worker) updateDatabaseStatus(ctx context.Context) {
	st := w.status
	optString := func(s string) sql.NullString {
		return sql.NullString{String: s, Valid: s != ""}
	}
	optTime := func(t time.Time) sql.NullTime {
		return sql.NullTime{Time: t, Valid: !t.IsZero()}
	}

	snapshot := &models.WorkerStatus{
		WorkerInstance:          w.instance,
		IsRunning:               st.IsRunning,
		IsPaused:                st.IsPaused,
		CurrentActivity:         optString(st.CurrentActivity),
		LastHeartbeat:           sql.NullTime{Time: time.Now(), Valid: true},
		LastRunStart:            optTime(st.LastRunStart),
		LastRunFinish:           optTime(st.LastRunFinish),
		LastRunError:            optString(st.LastRunError),
		TotalQuestionsGenerated: w.getTotalQuestionsGenerated(),
		TotalRuns:               len(w.history),
	}

	if err := w.workerService.UpdateWorkerStatus(ctx, w.instance, snapshot); err != nil {
		w.logger.Error(ctx, "Failed to update worker status in database", err, map[string]interface{}{
			"instance": w.instance,
		})
	}
}
// getTotalQuestionsGenerated returns the number of run-history records whose
// Status is "Success".
//
// Despite its name it counts successful runs, not generated questions: each
// record's Details text is not parsed.
func (w *Worker) getTotalQuestionsGenerated() int {
	successes := 0
	for i := range w.history {
		if w.history[i].Status != "Success" {
			continue
		}
		// TODO: parse the actual generated-question count out of Details.
		successes++
	}
	return successes
}
// generateNeededQuestions runs one question-generation pass over every user
// eligible for AI generation and returns a human-readable summary of the pass.
//
// It returns early without error when the worker is globally paused or when
// no eligible users exist. Users rejected by shouldProcessUser are skipped;
// a non-empty rejection reason is logged. The returned error is non-nil only
// when the global pause check or the user lookup fails. Per-user failures are
// reflected in the summary instead.
func (w *Worker) generateNeededQuestions(ctx context.Context) (result0 string, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "generate_needed_questions",
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, &err)

	// The global pause is checked before any work or logging.
	globalPaused, err := w.workerService.IsGlobalPaused(ctx)
	switch {
	case err != nil:
		span.RecordError(err)
		w.logger.Error(ctx, "Failed to check global pause status", err, map[string]interface{}{
			"instance": w.instance,
		})
		return "Error checking global pause status", err
	case globalPaused:
		span.SetAttributes(attribute.Bool("globally_paused", true))
		w.logger.Info(ctx, "Worker skipping question generation (globally paused)", map[string]interface{}{
			"instance": w.instance,
		})
		return "Run paused globally", nil
	}

	candidates, err := w.getEligibleAIUsers(ctx)
	if err != nil {
		return "Error getting users", err
	}
	if len(candidates) == 0 {
		w.logger.Info(ctx, "Worker: No active users with AI provider configuration found for question generation", map[string]interface{}{
			"instance": w.instance,
		})
		return "No active users with AI provider configuration found", nil
	}

	var (
		actions, checked, processed []string
		anyAttempted, anyFailed     bool
	)
	for _, candidate := range candidates {
		checked = append(checked, candidate.Username)

		proceed, skipReason := w.shouldProcessUser(ctx, &candidate)
		if !proceed {
			if skipReason != "" {
				w.logger.Info(ctx, "Worker user check", map[string]interface{}{
					"instance": w.instance,
					"username": candidate.Username,
					"reason":   skipReason,
				})
			}
			continue
		}

		processed = append(processed, candidate.Username)
		summary, attempted, failed := w.processUserQuestionGeneration(ctx, &candidate)
		anyAttempted = anyAttempted || attempted
		anyFailed = anyFailed || failed
		if summary != "" {
			actions = append(actions, summary)
		}

		w.logger.Info(ctx, "Worker completed check for user", map[string]interface{}{
			"instance": w.instance,
			"username": candidate.Username,
		})
	}

	w.updateActivity("")
	return w.summarizeRunActions(actions, checked, processed, anyAttempted, anyFailed), nil
}
// getEligibleAIUsers returns the users for whom AI question generation may run.
//
// A user qualifies when all of the following hold:
//   - AIEnabled is set and true.
//   - The user is not paused. A failed pause lookup counts as not paused.
//   - An AI provider is configured.
//
// Only a failure to list users is returned as an error.
func (w *Worker) getEligibleAIUsers(ctx context.Context) (result0 []models.User, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_eligible_ai_users",
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, &err)

	users, err := w.userService.GetAllUsers(ctx)
	if err != nil {
		span.RecordError(err)
		return nil, err
	}

	var eligible []models.User
	for _, u := range users {
		if !(u.AIEnabled.Valid && u.AIEnabled.Bool) {
			continue
		}
		if paused, pauseErr := w.workerService.IsUserPaused(ctx, u.ID); pauseErr == nil && paused {
			continue
		}

		providerSet := u.AIProvider.Valid && u.AIProvider.String != ""
		keySet := false
		if providerSet {
			key, keyErr := w.userService.GetUserAPIKey(ctx, u.ID, u.AIProvider.String)
			keySet = keyErr == nil && key != ""
		}

		// NOTE(review): keySet can only be true when providerSet is, so this
		// condition reduces to providerSet and the API-key lookup above never
		// changes eligibility. Confirm whether a saved key was meant to be required.
		if keySet || providerSet {
			eligible = append(eligible, u)
		}
	}
	return eligible, nil
}
// shouldProcessUser decides whether user should be processed in this run and,
// when not, returns a human-readable reason.
//
// Checks run in this order:
//  1. The per-user exponential backoff (shouldRetryUser).
//  2. Worker shutdown (ctx cancelled).
//  3. The global pause flag. An error reading it also skips the user.
//  4. This instance's pause flag. Errors reading it are ignored.
//
// A final cancellation check catches a shutdown that began during those
// service calls.
func (w *Worker) shouldProcessUser(ctx context.Context, user *models.User) (bool, string) {
	if !w.shouldRetryUser(user.ID) {
		// Build the message while holding the lock, because the failure entry
		// is shared and mutable. The entry may also be gone by now (Shutdown
		// replaces the whole map), so guard against nil instead of
		// dereferencing it unconditionally.
		w.failureMu.RLock()
		reason := "Skipping due to exponential backoff"
		if failure := w.userFailures[user.ID]; failure != nil {
			reason = fmt.Sprintf("Skipping due to exponential backoff (failure #%d, retry in %v)",
				failure.ConsecutiveFailures, time.Until(failure.NextRetryTime).Round(time.Second))
		}
		w.failureMu.RUnlock()
		return false, reason
	}

	// Skip the service calls entirely once shutdown has begun.
	if ctx.Err() != nil {
		return false, "Shutdown initiated"
	}

	globalPaused, err := w.workerService.IsGlobalPaused(ctx)
	if err != nil {
		return false, "Error checking global pause status"
	}
	if globalPaused {
		return false, "Run paused globally"
	}

	if status, statusErr := w.workerService.GetWorkerStatus(ctx, w.instance); statusErr == nil && status != nil && status.IsPaused {
		return false, fmt.Sprintf("Worker instance %s paused", w.instance)
	}

	if ctx.Err() != nil {
		return false, "Shutdown initiated"
	}
	return true, ""
}
// getEligibleQuestionCount returns how many active questions of the given
// language, level and type are assigned to userID (via user_questions).
// Questions the user answered correctly within a two-day window are excluded;
// the window is computed in the user's local time by
// contextutils.UserLocalDayRange.
//
// The user's timezone is looked up only through a concrete
// *services.UserService. Any other implementation yields (nil, nil), which
// makes the range helper fall back to UTC. The count itself needs the raw
// database handle, which only the concrete *services.QuestionService exposes.
// For any other implementation an error is returned.
func (w *Worker) getEligibleQuestionCount(ctx context.Context, userID int, language, level string, qType models.QuestionType) (result0 int, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_eligible_question_count",
		observability.AttributeUserID(userID),
		attribute.String("language", language),
		attribute.String("level", level),
		attribute.String("question.type", string(qType)),
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, &err)

	// Only the concrete UserService is consulted, so mocked services in unit
	// tests are never invoked.
	lookupUser := func(lookupCtx context.Context, id int) (*models.User, error) {
		concrete, isConcrete := w.userService.(*services.UserService)
		if !isConcrete || concrete == nil {
			return nil, nil
		}
		return concrete.GetUserByID(lookupCtx, id)
	}

	windowStart, windowEnd, _, err := contextutils.UserLocalDayRange(ctx, userID, 2, lookupUser)
	if err != nil {
		return 0, contextutils.WrapError(err, "failed to compute user local day range")
	}

	const query = `
SELECT COUNT(*)
FROM questions q
JOIN user_questions uq ON q.id = uq.question_id
WHERE uq.user_id = $1
AND q.language = $2
AND q.level = $3
AND q.type = $4
AND q.status = 'active'
AND NOT EXISTS (
SELECT 1 FROM user_responses ur
WHERE ur.user_id = $1
AND ur.question_id = q.id
AND ur.is_correct = TRUE
AND ur.created_at >= $5 AND ur.created_at < $6
)
`

	qs, isConcrete := w.questionService.(*services.QuestionService)
	if !isConcrete {
		// Mock or alternative implementations cannot provide the DB handle;
		// this is expected in unit tests.
		return 0, contextutils.ErrorWithContextf("cannot get database from question service implementation")
	}

	var count int
	if scanErr := qs.DB().QueryRowContext(ctx, query, userID, language, level, qType, windowStart, windowEnd).Scan(&count); scanErr != nil {
		return 0, scanErr
	}
	return count, nil
}
// processUserQuestionGeneration tops up the user's question pools for their
// preferred language and level, and reports what it did.
//
// Language and level default to "italian" and "A1" when the user has none set.
// Question types with an active generation hint are processed first, each
// type at most once. They are treated as if their pool were empty, so
// generation is always attempted for them, and the hint is cleared after a
// successful generation call. Every other type is topped up only when its
// eligible-question count is below cfg.Server.QuestionRefillThreshold.
//
// The batch size is the AI provider's batch size. It is increased when fewer
// fresh (never-answered) questions are available than the user's
// FreshQuestionRatio of the refill threshold requires (ratio 0.7 when the
// preferences are unavailable).
//
// It returns the "; "-joined action summaries, whether any generation was
// attempted, and whether any count or generation step failed.
func (w *Worker) processUserQuestionGeneration(ctx context.Context, user *models.User) (string, bool, bool) {
	ctx, span := observability.TraceWorkerFunction(ctx, "processUserQuestionGeneration",
		observability.AttributeUserID(user.ID),
		attribute.String("user.username", user.Username),
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, nil)

	userLanguage := "italian"
	if user.PreferredLanguage.Valid && user.PreferredLanguage.String != "" {
		userLanguage = user.PreferredLanguage.String
		span.SetAttributes(attribute.String("user.language", userLanguage))
	}
	userLevel := "A1"
	if user.CurrentLevel.Valid && user.CurrentLevel.String != "" {
		userLevel = user.CurrentLevel.String
		span.SetAttributes(attribute.String("user.level", userLevel))
	}
	languages := []string{userLanguage}
	levels := []string{userLevel}

	// Per-user settings that do not vary by question type.
	provider := "default"
	if user.AIProvider.Valid && user.AIProvider.String != "" {
		provider = user.AIProvider.String
	}
	refillThreshold := w.cfg.Server.QuestionRefillThreshold

	questionTypes := []models.QuestionType{
		models.Vocabulary,
		models.FillInBlank,
		models.QuestionAnswer,
		models.ReadingComprehension,
	}

	// Load the active hints once. The hinted set is the same for every
	// question type, so querying again inside the loop would only repeat work.
	hintedTypes := map[models.QuestionType]bool{}
	if w.hintService != nil {
		if hints, err := w.hintService.GetActiveHintsForUser(ctx, user.ID); err == nil && len(hints) > 0 {
			ordered := make([]models.QuestionType, 0, len(hints)+len(questionTypes))
			for _, h := range hints {
				qt := models.QuestionType(h.QuestionType)
				// Hints are cleared per (language, level, type), so several
				// active hints can share a type. Process each type only once
				// to avoid generating for it twice in one pass.
				if hintedTypes[qt] {
					continue
				}
				hintedTypes[qt] = true
				ordered = append(ordered, qt)
			}
			// Non-hinted types follow, keeping their original relative order.
			for _, qt := range questionTypes {
				if !hintedTypes[qt] {
					ordered = append(ordered, qt)
				}
			}
			questionTypes = ordered
		}
	}

	var actions []string
	var hadAttemptedOperations bool
	var hadFailures bool
	for _, language := range languages {
		for _, level := range levels {
			for _, qType := range questionTypes {
				w.updateActivity(fmt.Sprintf("Checking questions for user %s: %s %s %s", user.Username, language, level, qType))

				// Use the eligible question count, not just the total assigned.
				eligibleCount, err := w.getEligibleQuestionCount(ctx, user.ID, language, level, qType)
				if err != nil {
					span.RecordError(err)
					hadFailures = true
					continue
				}

				hinted := hintedTypes[qType]
				if hinted {
					// Treat the pool as empty to force generation while
					// keeping the normal batch sizing below.
					eligibleCount = 0
				}
				if eligibleCount >= refillThreshold {
					continue
				}

				baseBatchSize := w.aiService.GetQuestionBatchSize(provider)
				needed := baseBatchSize

				// The user's personal FreshQuestionRatio drives how many
				// never-answered questions should be available.
				userPrefs, prefsErr := w.learningService.GetUserLearningPreferences(ctx, user.ID)
				userFreshRatio := 0.7 // default fallback
				if prefsErr == nil && userPrefs != nil && userPrefs.FreshQuestionRatio > 0 {
					userFreshRatio = userPrefs.FreshQuestionRatio
				} else if prefsErr != nil {
					w.logger.Warn(ctx, "Failed to get user learning preferences, using default fresh ratio", map[string]interface{}{
						"user_id": user.ID,
						"error":   prefsErr.Error(),
					})
				}

				// Ensure enough fresh questions exist for daily assignment to
				// respect the user's freshness preference.
				desiredFresh := int(math.Ceil(float64(refillThreshold) * userFreshRatio))
				freshCandidates := 0
				if qs, qerr := w.questionService.GetAdaptiveQuestionsForDaily(ctx, user.ID, language, level, 50); qerr == nil && qs != nil {
					for _, q := range qs {
						if q != nil && q.TotalResponses == 0 {
							freshCandidates++
						}
					}
				} else if qerr != nil {
					// Log but don't fail; proceed with the base batch size.
					w.logger.Warn(ctx, "Failed to fetch adaptive questions for fresh-count check", map[string]interface{}{
						"user_id": user.ID,
						"error":   qerr.Error(),
					})
				}
				if missing := desiredFresh - freshCandidates; missing > 0 {
					needed += missing
					w.logger.Info(ctx, "Adjusting generation batch to meet user's personal fresh-question requirement", map[string]interface{}{
						"user_id":          user.ID,
						"language":         language,
						"level":            level,
						"question_type":    qType,
						"user_fresh_ratio": userFreshRatio,
						"base_batch_size":  baseBatchSize,
						"desired_fresh":    desiredFresh,
						"fresh_candidates": freshCandidates,
						"added_to_batch":   missing,
						"final_batch_size": needed,
					})
				}

				hadAttemptedOperations = true
				action, err := w.GenerateQuestionsForUser(ctx, user, language, level, qType, needed, "")
				if err != nil {
					hadFailures = true
					// Continue with the next question type rather than aborting.
					continue
				}
				if action != "" {
					actions = append(actions, action)
				}

				// Clearing the hint is best-effort; the error is deliberately ignored.
				if hinted && w.hintService != nil {
					_ = w.hintService.ClearHint(ctx, user.ID, language, level, qType)
				}
			}
		}
	}
	return strings.Join(actions, "; "), hadAttemptedOperations, hadFailures
}
// summarizeRunActions builds the human-readable summary stored for a run.
//
// With at least one action, it returns the actions (one per line) followed by
// a "Processed users: ..." line. With no actions, it explains why:
//   - no user passed shouldProcessUser (reported as exponential backoff),
//   - or an attempted generation failed,
//   - or every question type already had enough questions.
//
// checkedUsers lists every candidate; actuallyProcessedUsers lists only those
// that were processed.
//
// NOTE(review): failures that happen before any generation attempt (for
// example a failed eligible-count query) set hadFailures without
// hadAttemptedOperations. They therefore fall through to the "sufficient
// questions" message. Confirm whether that is intended.
func (w *Worker) summarizeRunActions(actions, checkedUsers, actuallyProcessedUsers []string, hadAttemptedOperations, hadFailures bool) string {
	checked := "No users with AI configuration found"
	if len(checkedUsers) > 0 {
		checked = fmt.Sprintf("Checked users: %s", strings.Join(checkedUsers, ", "))
	}

	if len(actions) == 0 {
		switch {
		case len(actuallyProcessedUsers) == 0:
			return fmt.Sprintf("No actions taken. All users in exponential backoff. %s", checked)
		case hadAttemptedOperations && hadFailures:
			return fmt.Sprintf("No actions taken due to errors. %s", checked)
		default:
			return fmt.Sprintf("No actions taken. All question types have sufficient questions. %s", checked)
		}
	}

	// Actions are newline-separated for readability in the UI. Joining a
	// single action yields that action unchanged, so no special case is needed.
	processed := fmt.Sprintf("Processed users: %s", strings.Join(actuallyProcessedUsers, ", "))
	return fmt.Sprintf("%s\n%s", strings.Join(actions, "\n"), processed)
}
// GenerateQuestionsForUser generates count questions of type qType for user in
// the given language and level, saves them and assigns them to the user.
//
// It gathers priority data (weak areas, high-priority topics, gap analysis,
// fresh-question ratio), selects variety elements from it, logs that decision,
// builds the AI request (with the user's recent question history when it can be
// fetched), streams questions from the user's configured AI provider and then
// persists them.
//
// It returns a human-readable progress/result message and an error. A
// non-positive count is a no-op. An AI error, an empty AI response, or a
// partial save records a user failure (used for backoff). When at least one
// question was saved, a success is recorded first.
func (w *Worker) GenerateQuestionsForUser(ctx context.Context, user *models.User, language, level string, qType models.QuestionType, count int, topic string) (result0 string, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "generate_questions_for_user",
		observability.AttributeUserID(user.ID),
		attribute.String("user.username", user.Username),
		attribute.String("language", language),
		attribute.String("level", level),
		attribute.String("question.type", string(qType)),
		attribute.Int("question.count", count),
		attribute.String("topic", topic),
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, &err)
	if count <= 0 {
		return "No questions needed", nil
	}

	// Gather priority data for variety selection
	priorityData := w.getPriorityGenerationData(ctx, user.ID, language, level, qType)
	var userWeakAreas []string
	if priorityData != nil && priorityData.FocusOnWeakAreas {
		userWeakAreas = priorityData.UserWeakAreas
	}
	var (
		highPriorityTopics []string
		gapAnalysis        map[string]int
		freshQuestionRatio float64
	)
	if priorityData != nil {
		highPriorityTopics = priorityData.HighPriorityTopics
		gapAnalysis = priorityData.GapAnalysis
		freshQuestionRatio = priorityData.FreshQuestionRatio
	}
	variety := w.aiService.VarietyService().SelectVarietyElements(ctx, level, highPriorityTopics, userWeakAreas, gapAnalysis)

	// Log priority generation decisions
	generationReasoning := w.getGenerationReasoning(priorityData, variety)
	priorityLog := PriorityGenerationLog{
		UserID:              user.ID,
		Username:            user.Username,
		Language:            language,
		Level:               level,
		QuestionType:        string(qType),
		FocusOnWeakAreas:    priorityData != nil && priorityData.FocusOnWeakAreas,
		UserWeakAreas:       userWeakAreas,
		HighPriorityTopics:  highPriorityTopics,
		GapAnalysis:         gapAnalysis,
		FreshQuestionRatio:  freshQuestionRatio,
		SelectedVariety:     variety,
		GenerationReasoning: generationReasoning,
		Timestamp:           time.Now(),
	}
	w.logPriorityGeneration(ctx, priorityLog)

	aiReq, recentQuestions, err := w.buildAIQuestionGenRequest(ctx, user, language, level, qType, count, topic)
	if err != nil {
		w.logger.Warn(ctx, "Worker failed to get recent questions", map[string]interface{}{
			"instance": w.instance,
			"error":    err.Error(),
		})
		// buildAIQuestionGenRequest returns a nil request on error. Build an
		// equivalent request here so generation proceeds without history
		// instead of dereferencing nil below.
		aiReq = &models.AIQuestionGenRequest{
			Language:     language,
			Level:        level,
			QuestionType: qType,
			Count:        count,
		}
		recentQuestions = []string{}
	}
	aiReq.RecentQuestionHistory = recentQuestions

	userConfig := w.getUserAIConfig(ctx, user)
	batchLogMsg := formatBatchLogMessage(user.Username, count, string(qType), language, level, variety, userConfig.Provider, userConfig.Model)
	w.logger.Info(ctx, batchLogMsg, map[string]interface{}{
		"instance": w.instance,
	})
	w.updateActivity(batchLogMsg)
	w.logActivity(ctx, "INFO", batchLogMsg, &user.ID, &user.Username)

	progressMsg, questions, errAI := w.handleAIQuestionStream(ctx, userConfig, aiReq, variety, count, language, level, qType, topic, user)
	if errAI != nil {
		w.recordUserFailure(ctx, user.ID, user.Username)
		return progressMsg, errAI
	}
	if len(questions) == 0 {
		w.recordUserFailure(ctx, user.ID, user.Username)
		return progressMsg, contextutils.WrapErrorf(contextutils.ErrAIResponseInvalid, "AI service returned 0 questions for %s %s %s", language, level, qType)
	}

	savedCount := w.saveGeneratedQuestions(ctx, user, questions, language, level, qType, topic, variety)
	if savedCount > 0 {
		w.recordUserSuccess(ctx, user.ID, user.Username)
	}
	if savedCount != len(questions) {
		// A partial save still counts as a failure for backoff purposes. Because
		// any earlier success cleared the record, this restarts the count at one.
		w.recordUserFailure(ctx, user.ID, user.Username)
		return fmt.Sprintf("Generated %d/%d %s questions for %s %s", savedCount, len(questions), qType, language, level),
			contextutils.WrapErrorf(contextutils.ErrDatabaseQuery, "only saved %d out of %d generated questions", savedCount, len(questions))
	}
	return fmt.Sprintf("Generated %d %s questions for %s %s", savedCount, qType, language, level), nil
}
// buildAIQuestionGenRequest assembles the AI generation request for one batch.
//
// It fetches the user's recent question contents (limit 10, via
// GetRecentQuestionContentsForUser) and returns them twice: attached to the
// request as RecentQuestionHistory, and as a separate slice. If that lookup
// fails, it records the error on the span and returns a nil request, a nil
// slice and the error. The trailing topic argument is currently unused.
func (w *Worker) buildAIQuestionGenRequest(ctx context.Context, user *models.User, language, level string, qType models.QuestionType, count int, _ string) (result0 *models.AIQuestionGenRequest, result1 []string, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "build_ai_question_gen_request",
		observability.AttributeUserID(user.ID),
		attribute.String("user.username", user.Username),
		attribute.String("language", language),
		attribute.String("level", level),
		attribute.String("question.type", string(qType)),
		attribute.Int("question.count", count),
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, &err)

	history, lookupErr := w.questionService.GetRecentQuestionContentsForUser(ctx, user.ID, 10)
	if lookupErr != nil {
		span.RecordError(lookupErr)
		return nil, nil, lookupErr
	}

	req := &models.AIQuestionGenRequest{
		Language:              language,
		Level:                 level,
		QuestionType:          qType,
		Count:                 count,
		RecentQuestionHistory: history,
	}
	return req, history, nil
}
// getUserAIConfig builds the AI configuration used to generate questions for
// user.
//
// Provider and model come from the user's nullable AIProvider/AIModel columns
// (empty when NULL) and are added to the span when present. When a provider is
// set, the saved API key for that provider is looked up. The lookup is
// best-effort: on error or an empty key, APIKey stays "". Lookup errors are now
// recorded on the span and logged instead of being silently dropped.
func (w *Worker) getUserAIConfig(ctx context.Context, user *models.User) *services.UserAIConfig {
	ctx, span := observability.TraceWorkerFunction(ctx, "get_user_ai_config",
		observability.AttributeUserID(user.ID),
		attribute.String("user.username", user.Username),
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, nil)

	provider := ""
	if user.AIProvider.Valid {
		provider = user.AIProvider.String
		span.SetAttributes(attribute.String("ai.provider", provider))
	}
	model := ""
	if user.AIModel.Valid {
		model = user.AIModel.String
		span.SetAttributes(attribute.String("ai.model", model))
	}

	apiKey := ""
	if provider != "" {
		savedKey, err := w.userService.GetUserAPIKey(ctx, user.ID, provider)
		switch {
		case err != nil:
			// Proceed without a key (the previous behavior), but leave a trace
			// so a failing key lookup is diagnosable rather than silent.
			span.RecordError(err)
			w.logger.Warn(ctx, "Worker failed to get user API key", map[string]interface{}{
				"instance": w.instance,
				"user_id":  user.ID,
				"provider": provider,
				"error":    err.Error(),
			})
		case savedKey != "":
			apiKey = savedKey
		}
	}

	return &services.UserAIConfig{
		Provider: provider,
		Model:    model,
		APIKey:   apiKey,
		Username: user.Username,
	}
}
// handleAIQuestionStream runs the AI provider's streaming generation in a
// goroutine and collects every question it sends on the progress channel.
//
// For each received question it logs a "Generated i/count ..." progress
// message (with the topic when one is set) and updates the worker's current
// activity and activity log. It returns the last progress message (empty if
// nothing arrived), the collected questions and the error from
// GenerateQuestionsStream.
//
// A panic inside the stream goroutine is recovered, logged and reported as a
// non-nil error. Previously it left the error nil, so a crashed stream looked
// successful to callers.
//
// NOTE(review): the range loop below ends only when progressChan is closed.
// This function never closes it, so it relies on GenerateQuestionsStream
// closing it, including on the panic path. Confirm against the AI service.
func (w *Worker) handleAIQuestionStream(ctx context.Context, userConfig *services.UserAIConfig, req *models.AIQuestionGenRequest, variety *services.VarietyElements, count int, language, level string, qType models.QuestionType, topic string, user *models.User) (result0 string, result1 []*models.Question, err error) {
	ctx, span := observability.TraceWorkerFunction(ctx, "handle_ai_question_stream",
		attribute.String("ai.provider", userConfig.Provider),
		attribute.String("ai.model", userConfig.Model),
		attribute.String("language", language),
		attribute.String("level", level),
		attribute.String("question.type", string(qType)),
		attribute.Int("question.count", count),
		attribute.String("topic", topic),
		attribute.String("user.username", user.Username),
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, &err)

	progressChan := make(chan *models.Question)
	var questions []*models.Question
	var wg sync.WaitGroup
	// errAI is written only by the goroutine and read only after wg.Wait(),
	// so Wait provides the required happens-before edge.
	var errAI error
	progressMsg := ""

	wg.Add(1)
	go func() {
		defer func() {
			if r := recover(); r != nil {
				panicErr := fmt.Errorf("panic in AI question stream: %v", r)
				w.logger.Error(ctx, "Panic in AI question stream goroutine", panicErr, map[string]interface{}{
					"instance": w.instance,
					"panic":    fmt.Sprintf("%v", r),
				})
				// Surface the crash to the caller instead of returning a nil error.
				errAI = panicErr
			}
			wg.Done()
		}()
		errAI = w.aiService.GenerateQuestionsStream(ctx, userConfig, req, progressChan, variety)
	}()

	generatedCount := 0
	for question := range progressChan {
		generatedCount++
		progressMsg = fmt.Sprintf("Generated %d/%d %s questions for %s %s", generatedCount, count, qType, language, level)
		if topic != "" {
			progressMsg = fmt.Sprintf("Generated %d/%d %s questions for %s %s (topic: %s)", generatedCount, count, qType, language, level, topic)
		}
		w.logger.Info(ctx, progressMsg, map[string]interface{}{
			"instance": w.instance,
		})
		w.updateActivity(progressMsg)
		w.logActivity(ctx, "INFO", progressMsg, &user.ID, &user.Username)
		questions = append(questions, question)
	}
	wg.Wait()
	return progressMsg, questions, errAI
}
// saveGeneratedQuestions persists freshly generated questions and assigns each
// one to user, returning how many were both saved and assigned.
//
// Before saving, every question is stamped with the variety elements used
// during generation (when variety is non-nil). A failure to save or to assign a
// question is logged and that question is skipped; the loop continues with the
// rest. When at least one question succeeds, a success entry is added to the
// activity log.
func (w *Worker) saveGeneratedQuestions(ctx context.Context, user *models.User, questions []*models.Question, language, level string, qType models.QuestionType, topic string, variety *services.VarietyElements) int {
	ctx, span := observability.TraceWorkerFunction(ctx, "save_generated_questions",
		observability.AttributeUserID(user.ID),
		attribute.String("user.username", user.Username),
		attribute.String("language", language),
		attribute.String("level", level),
		attribute.String("question.type", string(qType)),
		attribute.Int("question.count", len(questions)),
		attribute.String("topic", topic),
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, nil)

	var savingMsg string
	if topic == "" {
		savingMsg = fmt.Sprintf("Saving %d new %s questions for %s %s", len(questions), qType, language, level)
	} else {
		savingMsg = fmt.Sprintf("Saving %d new %s questions for %s %s (topic: %s)", len(questions), qType, language, level, topic)
	}
	w.logger.Info(ctx, savingMsg, map[string]interface{}{
		"instance": w.instance,
	})
	w.updateActivity(savingMsg)
	w.logActivity(ctx, "INFO", savingMsg, &user.ID, &user.Username)

	stored := 0
	for _, question := range questions {
		if variety != nil {
			// Record which variety elements shaped this question.
			question.TopicCategory = variety.TopicCategory
			question.GrammarFocus = variety.GrammarFocus
			question.VocabularyDomain = variety.VocabularyDomain
			question.Scenario = variety.Scenario
			question.StyleModifier = variety.StyleModifier
			question.DifficultyModifier = variety.DifficultyModifier
			question.TimeContext = variety.TimeContext
		}

		if saveErr := w.questionService.SaveQuestion(ctx, question); saveErr != nil {
			w.logger.Error(ctx, "Failed to save generated question", saveErr, map[string]interface{}{
				"instance":      w.instance,
				"user_id":       user.ID,
				"language":      language,
				"level":         level,
				"question_type": qType,
			})
			continue
		}

		if assignErr := w.questionService.AssignQuestionToUser(ctx, question.ID, user.ID); assignErr != nil {
			w.logger.Error(ctx, "Failed to assign question to user", assignErr, map[string]interface{}{
				"instance":    w.instance,
				"question_id": question.ID,
				"user_id":     user.ID,
			})
			continue
		}

		stored++
	}

	if stored == 0 {
		return 0
	}

	var successMsg string
	if topic == "" {
		successMsg = fmt.Sprintf("Successfully saved %d new '%s' questions for %s %s", stored, qType, language, level)
	} else {
		successMsg = fmt.Sprintf("Successfully saved %d new '%s' questions for %s %s (topic: %s)", stored, qType, language, level, topic)
	}
	w.logActivity(ctx, "INFO", successMsg, &user.ID, &user.Username)
	return stored
}
// updateActivity sets the worker status's CurrentActivity to the given text.
// The write happens while holding w.mu; the deferred unlock releases the lock
// even if the assignment panics.
func (w *Worker) updateActivity(activity string) {
	w.mu.Lock()
	defer w.mu.Unlock()

	w.status.CurrentActivity = activity
}
// logActivity appends an entry to the worker's in-memory activity log, with
// optional user attribution.
//
// level is stored on the entry. Previously it was ignored and every entry was
// stamped "INFO". An empty level still defaults to "INFO". Only the newest
// cfg.Server.MaxActivityLogs entries are kept. A negative limit is treated as
// zero so the trim can never slice out of range (which used to panic).
// The log is mutated while holding w.mu.
func (w *Worker) logActivity(_ context.Context, level, message string, userID *int, username *string) {
	if level == "" {
		level = "INFO"
	}

	w.mu.Lock()
	defer w.mu.Unlock()

	logEntry := ActivityLog{
		Timestamp: time.Now(),
		Level:     level,
		Message:   message,
		UserID:    userID,
		Username:  username,
	}
	w.activityLogs = append(w.activityLogs, logEntry)

	// Keep only the last MaxActivityLogs entries.
	limit := w.cfg.Server.MaxActivityLogs
	if limit < 0 {
		limit = 0
	}
	if len(w.activityLogs) > limit {
		w.activityLogs = w.activityLogs[len(w.activityLogs)-limit:]
	}
}
// shouldRetryUser reports whether userID may be processed now. It returns true
// when no failure is recorded for the user; otherwise it returns true only once
// the recorded NextRetryTime has passed. The lookup holds failureMu's read
// lock.
func (w *Worker) shouldRetryUser(userID int) bool {
	w.failureMu.RLock()
	defer w.failureMu.RUnlock()

	if info, found := w.userFailures[userID]; found {
		return time.Now().After(info.NextRetryTime)
	}
	// No recorded failures: nothing to back off from.
	return true
}
// recordUserFailure registers one more consecutive failure for userID and
// schedules its next retry with exponential backoff: 2^failures seconds,
// capped at one hour (3600 s). The failure record is created on the first
// failure. userFailures is mutated while holding failureMu.
func (w *Worker) recordUserFailure(ctx context.Context, userID int, username string) {
	ctx, span := observability.TraceWorkerFunction(ctx, "record_user_failure",
		observability.AttributeUserID(userID),
		attribute.String("user.username", username),
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, nil)

	w.failureMu.Lock()
	defer w.failureMu.Unlock()

	failure, exists := w.userFailures[userID]
	if !exists {
		failure = &UserFailureInfo{}
		w.userFailures[userID] = failure
	}
	failure.ConsecutiveFailures++
	now := time.Now()
	failure.LastFailureTime = now

	// Exponential backoff, max 1 hour. An integer shift is used only while it
	// stays below the cap (1<<12 = 4096 > 3600). The previous
	// int(math.Pow(2, n)) overflowed at n >= 63. Out-of-range float->int
	// conversions are implementation-dependent and negative on amd64, which
	// skipped the cap and removed the backoff entirely.
	const maxBackoffSeconds = 3600
	backoffSeconds := maxBackoffSeconds
	if n := failure.ConsecutiveFailures; n >= 0 && n < 12 {
		backoffSeconds = 1 << n
	}
	failure.NextRetryTime = now.Add(time.Duration(backoffSeconds) * time.Second)

	span.SetAttributes(
		attribute.Int("failure.count", failure.ConsecutiveFailures),
		attribute.Int("backoff.seconds", backoffSeconds),
	)
	w.logger.Info(ctx, "Worker recorded user failure", map[string]interface{}{
		"instance":           w.instance,
		"username":           username,
		"failure_count":      failure.ConsecutiveFailures,
		"next_retry_seconds": backoffSeconds,
	})
}
// recordUserSuccess clears any recorded failures for userID, which resets its
// backoff. It does nothing unless the user has a record with a positive failure
// count. In that case it logs the reset, annotates the span with the previous
// count and deletes the record while holding failureMu.
func (w *Worker) recordUserSuccess(ctx context.Context, userID int, username string) {
	ctx, span := observability.TraceWorkerFunction(ctx, "record_user_success",
		observability.AttributeUserID(userID),
		attribute.String("user.username", username),
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, nil)

	w.failureMu.Lock()
	defer w.failureMu.Unlock()

	info, found := w.userFailures[userID]
	if !found || info.ConsecutiveFailures <= 0 {
		return
	}

	prev := info.ConsecutiveFailures
	span.SetAttributes(attribute.Int("previous_failures", prev))
	w.logger.Info(ctx, "Worker user success after failures, resetting backoff", map[string]interface{}{
		"instance":          w.instance,
		"username":          username,
		"previous_failures": prev,
	})
	delete(w.userFailures, userID)
}
// formatBatchLogMessage builds the one-line message logged before a generation
// batch. The message contains the user, batch size, question type, language
// and level, then every non-empty variety element (grammar, topic, scenario,
// style, difficulty, vocab, time, in that order), and finally the AI
// provider/model pair. Segments are joined with " | ".
func formatBatchLogMessage(username string, count int, qType, language, level string, variety *services.VarietyElements, provider, model string) string {
	var parts []string
	if variety != nil {
		labelled := [...]struct{ label, value string }{
			{"grammar: ", variety.GrammarFocus},
			{"topic: ", variety.TopicCategory},
			{"scenario: ", variety.Scenario},
			{"style: ", variety.StyleModifier},
			{"difficulty: ", variety.DifficultyModifier},
			{"vocab: ", variety.VocabularyDomain},
			{"time: ", variety.TimeContext},
		}
		for _, field := range labelled {
			if field.value != "" {
				parts = append(parts, field.label+field.value)
			}
		}
	}
	// The provider/model pair is always present and always last.
	parts = append(parts, "provider: "+provider+", model: "+model)
	return fmt.Sprintf("Worker [user=%s]: Batch %d %s questions (lang: %s, level: %s) | %s", username, count, qType, language, level, strings.Join(parts, " | "))
}
// PriorityGenerationData contains priority information to guide AI question generation
//
// It is assembled by getPriorityGenerationData for one (user, language, level,
// question type) combination. Field order also determines the JSON key order.
type PriorityGenerationData struct {
// UserWeakAreas holds the non-empty "topic" strings taken from the user's weak areas.
UserWeakAreas []string `json:"user_weak_areas,omitempty"`
// HighPriorityTopics comes from workerService.GetHighPriorityTopics.
HighPriorityTopics []string `json:"high_priority_topics,omitempty"`
// GapAnalysis comes from workerService.GetGapAnalysis; presumably topic -> number
// of missing questions (TODO confirm against the worker service).
GapAnalysis map[string]int `json:"gap_analysis,omitempty"`
// UserPreferences is the user's loaded learning preferences, or the worker defaults
// when they could not be loaded.
UserPreferences *models.UserLearningPreferences `json:"user_preferences,omitempty"`
// PriorityDistribution comes from workerService.GetPriorityDistribution.
PriorityDistribution map[string]int `json:"priority_distribution,omitempty"`
// FocusOnWeakAreas is true only when weak areas exist and the user's preferences enable the focus.
FocusOnWeakAreas bool `json:"focus_on_weak_areas"`
// FreshQuestionRatio is copied from the user's preferences (e.g. 0.3 in the defaults).
FreshQuestionRatio float64 `json:"fresh_question_ratio"`
}
// PriorityGenerationLog contains structured data about priority-aware generation decisions
//
// One value is built per GenerateQuestionsForUser call and emitted as a JSON
// string by logPriorityGeneration. Field order determines the JSON key order.
type PriorityGenerationLog struct {
UserID int `json:"user_id"`
Username string `json:"username"`
Language string `json:"language"`
Level string `json:"level"`
QuestionType string `json:"question_type"`
FocusOnWeakAreas bool `json:"focus_on_weak_areas"`
// UserWeakAreas is populated only when FocusOnWeakAreas is true.
UserWeakAreas []string `json:"user_weak_areas,omitempty"`
HighPriorityTopics []string `json:"high_priority_topics,omitempty"`
GapAnalysis map[string]int `json:"gap_analysis,omitempty"`
FreshQuestionRatio float64 `json:"fresh_question_ratio"`
// SelectedVariety is the variety element set chosen for this batch.
SelectedVariety *services.VarietyElements `json:"selected_variety"`
// GenerationReasoning is the human-readable summary from getGenerationReasoning.
GenerationReasoning string `json:"generation_reasoning"`
Timestamp time.Time `json:"timestamp"`
}
// logPriorityGeneration serializes priorityLog to JSON and emits it as the
// "data" field of an info-level "Worker priority generation" log line. If
// marshaling fails, the error is recorded on the span and logged, and nothing
// else is emitted.
func (w *Worker) logPriorityGeneration(ctx context.Context, priorityLog PriorityGenerationLog) {
	ctx, span := observability.TraceWorkerFunction(ctx, "log_priority_generation",
		observability.AttributeUserID(priorityLog.UserID),
		attribute.String("user.username", priorityLog.Username),
		attribute.String("language", priorityLog.Language),
		attribute.String("level", priorityLog.Level),
		attribute.String("question.type", priorityLog.QuestionType),
		attribute.String("worker.instance", w.instance),
	)
	defer observability.FinishSpan(span, nil)

	payload, marshalErr := json.Marshal(priorityLog)
	if marshalErr == nil {
		w.logger.Info(ctx, "Worker priority generation", map[string]interface{}{
			"instance": w.instance,
			"data":     string(payload),
		})
		return
	}

	span.RecordError(marshalErr)
	w.logger.Error(ctx, "Failed to marshal priority generation log", marshalErr, map[string]interface{}{
		"instance": w.instance,
	})
}
// getGenerationReasoning returns a human-readable explanation of the
// generation strategy, used in logs.
//
// It returns "standard generation" when priorityData is nil or nothing
// noteworthy applies. Otherwise it returns a "; "-separated list covering
// weak-area focus, high-priority topics, gap-analysis entries formatted as
// name(count), the fresh-question ratio as a percentage, and the selected
// variety elements (topic, grammar, vocab, scenario).
//
// Gap-analysis entries follow Go map iteration order, which is unspecified, so
// their order can differ from call to call.
func (w *Worker) getGenerationReasoning(priorityData *PriorityGenerationData, variety *services.VarietyElements) string {
	const standard = "standard generation"
	if priorityData == nil {
		return standard
	}

	var reasons []string
	if priorityData.FocusOnWeakAreas && len(priorityData.UserWeakAreas) > 0 {
		reasons = append(reasons, fmt.Sprintf("focusing on weak areas: %s", strings.Join(priorityData.UserWeakAreas, ", ")))
	}
	if len(priorityData.HighPriorityTopics) > 0 {
		reasons = append(reasons, fmt.Sprintf("high priority topics: %s", strings.Join(priorityData.HighPriorityTopics, ", ")))
	}
	if len(priorityData.GapAnalysis) > 0 {
		gaps := make([]string, 0, len(priorityData.GapAnalysis))
		for name, n := range priorityData.GapAnalysis {
			gaps = append(gaps, fmt.Sprintf("%s(%d)", name, n))
		}
		reasons = append(reasons, fmt.Sprintf("gap analysis: %s", strings.Join(gaps, ", ")))
	}
	if priorityData.FreshQuestionRatio > 0 {
		reasons = append(reasons, fmt.Sprintf("fresh ratio: %.1f%%", priorityData.FreshQuestionRatio*100))
	}

	if variety != nil {
		candidates := [...]struct{ format, value string }{
			{"topic:%s", variety.TopicCategory},
			{"grammar:%s", variety.GrammarFocus},
			{"vocab:%s", variety.VocabularyDomain},
			{"scenario:%s", variety.Scenario},
		}
		var picked []string
		for _, c := range candidates {
			if c.value != "" {
				picked = append(picked, fmt.Sprintf(c.format, c.value))
			}
		}
		if len(picked) > 0 {
			reasons = append(reasons, fmt.Sprintf("variety: %s", strings.Join(picked, ", ")))
		}
	}

	if len(reasons) == 0 {
		return standard
	}
	return strings.Join(reasons, "; ")
}
// getPriorityGenerationData gathers the priority data used to guide AI
// question generation for one user, language, level and question type.
//
// Each input is best-effort. If a lookup fails, a warning is logged and a
// neutral value is used instead: default learning preferences, no weak areas,
// no high-priority topics, and empty gap-analysis and priority-distribution
// maps. A nil preferences result without an error is also replaced by the
// defaults. Previously that case panicked when FreshQuestionRatio was read.
// The returned value is never nil.
func (w *Worker) getPriorityGenerationData(ctx context.Context, userID int, language, level string, questionType models.QuestionType) *PriorityGenerationData {
	// warn logs a failed lookup in the same shape for every input below.
	warn := func(msg string, err error) {
		w.logger.Warn(ctx, msg, map[string]interface{}{
			"instance": w.instance,
			"user_id":  userID,
			"error":    err.Error(),
		})
	}

	// Get user preferences
	prefs, err := w.learningService.GetUserLearningPreferences(ctx, userID)
	if err != nil {
		warn("Worker failed to get user preferences", err)
		prefs = nil
	}
	if prefs == nil {
		prefs = w.getDefaultLearningPreferences()
	}

	// Get weak areas
	weakAreas, err := w.learningService.GetUserWeakAreas(ctx, userID, 5)
	if err != nil {
		warn("Worker failed to get weak areas", err)
		weakAreas = []map[string]interface{}{}
	}
	// Convert weak areas to topic strings
	var weakAreaTopics []string
	for _, area := range weakAreas {
		if topic, ok := area["topic"].(string); ok && topic != "" {
			weakAreaTopics = append(weakAreaTopics, topic)
		}
	}

	// Get high priority topics
	highPriorityTopics, err := w.getHighPriorityTopics(ctx, userID, language, level, questionType)
	if err != nil {
		warn("Worker failed to get high priority topics", err)
		highPriorityTopics = []string{}
	}

	// Get gap analysis
	gapAnalysis, err := w.getGapAnalysis(ctx, userID, language, level, questionType)
	if err != nil {
		warn("Worker failed to get gap analysis", err)
		gapAnalysis = map[string]int{}
	}

	// Get priority distribution
	priorityDistribution, err := w.getPriorityDistribution(ctx, userID, language, level, questionType)
	if err != nil {
		warn("Worker failed to get priority distribution", err)
		priorityDistribution = map[string]int{}
	}

	// Focus on weak areas only when there are some and the user opted in.
	focusOnWeakAreas := len(weakAreaTopics) > 0 && prefs.FocusOnWeakAreas
	return &PriorityGenerationData{
		UserWeakAreas:        weakAreaTopics,
		HighPriorityTopics:   highPriorityTopics,
		GapAnalysis:          gapAnalysis,
		UserPreferences:      prefs,
		PriorityDistribution: priorityDistribution,
		FocusOnWeakAreas:     focusOnWeakAreas,
		FreshQuestionRatio:   prefs.FreshQuestionRatio,
	}
}
// getDefaultLearningPreferences returns a newly allocated preferences value
// for use when a user's stored preferences are unavailable: weak-area focus
// disabled, a fresh-question ratio of 0.3 and a weak-area boost of 1.5.
func (w *Worker) getDefaultLearningPreferences() *models.UserLearningPreferences {
	defaults := models.UserLearningPreferences{
		// FocusOnWeakAreas keeps its zero value (false).
		FreshQuestionRatio: 0.3,
		WeakAreaBoost:      1.5,
	}
	return &defaults
}
// getHighPriorityTopics returns topics that have high average priority scores.
// It delegates to workerService.GetHighPriorityTopics and passes the question
// type in its string form.
func (w *Worker) getHighPriorityTopics(ctx context.Context, userID int, language, level string, questionType models.QuestionType) (result0 []string, err error) {
	typeName := string(questionType)
	result0, err = w.workerService.GetHighPriorityTopics(ctx, userID, language, level, typeName)
	return result0, err
}
// getGapAnalysis identifies areas with insufficient questions available.
// It delegates to workerService.GetGapAnalysis and passes the question type
// in its string form.
func (w *Worker) getGapAnalysis(ctx context.Context, userID int, language, level string, questionType models.QuestionType) (result0 map[string]int, err error) {
	typeName := string(questionType)
	result0, err = w.workerService.GetGapAnalysis(ctx, userID, language, level, typeName)
	return result0, err
}
// getPriorityDistribution returns the distribution of priority scores.
// It delegates to workerService.GetPriorityDistribution and passes the
// question type in its string form.
func (w *Worker) getPriorityDistribution(ctx context.Context, userID int, language, level string, questionType models.QuestionType) (result0 map[string]int, err error) {
	typeName := string(questionType)
	result0, err = w.workerService.GetPriorityDistribution(ctx, userID, language, level, typeName)
	return result0, err
}